You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

Untitled.ipynb 1.5KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 33,
  6. "metadata": {},
  7. "outputs": [
  8. {
  9. "name": "stdout",
  10. "output_type": "stream",
  11. "text": [
  12. "[0 1]\n",
  13. "[1 0]\n"
  14. ]
  15. }
  16. ],
  17. "source": [
  18. "import numpy as np\n",
  19. "\n",
  20. "y_out = [[[-0.52742714, -0.8918941 , -0.53989583, -0.874211 ]]]\n",
  21. "\n",
  22. "outputs = y_out[0][0]\n",
  23. "\n",
  24. "prob_action1 = outputs[:2]\n",
  25. "prob_action2 = outputs[2:]\n",
  26. "\n",
  27. "norm_action1 = [float(i)/sum(prob_action1) for i in prob_action1]\n",
  28. "norm_action2 = [float(i)/sum(prob_action2) for i in prob_action2]\n",
  29. "\n",
  30. "action1 = np.random.multinomial(1,norm_action1)\n",
  31. "action2 = np.random.multinomial(1,norm_action2)\n",
  32. "print(action1)\n",
  33. "print(action2)"
  34. ]
  35. },
  36. {
  37. "cell_type": "code",
  38. "execution_count": null,
  39. "metadata": {},
  40. "outputs": [],
  41. "source": []
  42. },
  43. {
  44. "cell_type": "code",
  45. "execution_count": null,
  46. "metadata": {},
  47. "outputs": [],
  48. "source": []
  49. },
  50. {
  51. "cell_type": "code",
  52. "execution_count": null,
  53. "metadata": {},
  54. "outputs": [],
  55. "source": []
  56. }
  57. ],
  58. "metadata": {
  59. "kernelspec": {
  60. "display_name": "Python 3",
  61. "language": "python",
  62. "name": "python3"
  63. },
  64. "language_info": {
  65. "codemirror_mode": {
  66. "name": "ipython",
  67. "version": 3
  68. },
  69. "file_extension": ".py",
  70. "mimetype": "text/x-python",
  71. "name": "python",
  72. "nbconvert_exporter": "python",
  73. "pygments_lexer": "ipython3",
  74. "version": "3.6.7"
  75. }
  76. },
  77. "nbformat": 4,
  78. "nbformat_minor": 2
  79. }