weixin_39710462
2021-01-07 11:56 — PPO training error with rewards other than 0
Getting the error below while training a PPO agent. While debugging I discovered that the error occurs when I pass a reward other than 0; with a reward of 0, everything works fine.
I make the agent using the following code
class PPOTrainer:
    """Thin wrapper around a tensorforce agent built from a JSON spec.

    Bridges a gym-style environment loop (act / observe) to the
    tensorforce Agent API.
    """

    def __init__(self, state_dim, action_dim):
        """Build the agent from the spec in rl_models/data.json.

        state_dim / action_dim are gym-style spaces; they are converted to
        tensorforce state/action specs via the OpenAIGym helpers.
        """
        # The JSON file holds an agent *spec* (e.g. the example ppo.json),
        # not an agent instance — name it accordingly to avoid shadowing
        # self.agent below.
        with open("rl_models/data.json") as fp:
            spec = json.load(fp=fp)
        network = [dict(type='dense', size=64), dict(type='dense', size=64)]
        self.agent = Agent.from_spec(
            spec=spec,
            kwargs=dict(
                states=OpenAIGym.state_from_space(state_dim),
                actions=OpenAIGym.action_from_space(action_dim),
                network=network,
            ),
        )

    def action(self, state):
        """Return the agent's action for the given state.

        NOTE(review): the original passed ``Independent=True``; tensorforce's
        ``Agent.act`` keyword is lowercase ``independent`` — fixed here.
        Verify against the installed tensorforce version's signature.
        """
        return self.agent.act(state, independent=True)

    def update(self, reward, terminal):
        """Feed the observed reward and terminal flag back to the agent."""
        self.agent.observe(reward, terminal)
where data.json is the ppo.json file that can be found in the examples folder. Update: I have tried other agents as well; apparently the problem occurs with all of them.
InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1322 try:
-> 1323 return fn(*args)
1324 except errors.OpError as e:
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
1301 feed_dict, fetch_list, target_list,
-> 1302 status, run_metadata)
1303
/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
472 compat.as_text(c_api.TF_Message(self.status.status)),
--> 473 c_api.TF_GetCode(self.status.status))
474 # Delete the underlying status object from memory otherwise it stays alive
InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:], got updates.shape [2,4], indices.shape [1], params.shape [5000,4]
[[Node: ppo/observe-timestep/store/ScatterUpdate = ScatterUpdate[T=DT_FLOAT, Tindices=DT_INT32, _class=["loc:/initialize/latest/initialize/state-state"], use_locking=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](ppo/initialize/latest/initialize/state-state, ppo/observe-timestep/store/mod/_367, ppo/strided_slice, ^ppo/observe-timestep/store/AssignSub/_371)]]
[[Node: ppo/observe-timestep/cond/optimization/multi-step/step/subsampling-step/step/adam/step/add_7/_512 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_2223_ppo/observe-timestep/cond/optimization/multi-step/step/subsampling-step/step/adam/step/add_7", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
During handling of the above exception, another exception occurred:
InvalidArgumentError Traceback (most recent call last)
~/environments/vip_protection/main_code/vip_protection_rl.py in <module>()
195 if __name__ == '__main__':
196 arglist = parse_args()
--> 197 train(arglist)
~/environments/vip_protection/main_code/vip_protection_rl.py in train(arglist)
143 for i, agent in enumerate(trainers):
144 print(rew_n[i])
--> 145 agent.update(rew_n[i], terminal)
146
147 if done or terminal:
~/environments/vip_protection/main_code/rl_models/agents.py in update(self, reward, terminal)
13
14 def update(self, reward, terminal):
---> 15 self.agent.observe(reward, terminal)
16 #
17 # class PPOTrainer(PPOAgent):
~/environments/vip_protection/tensorforce/tensorforce/agents/agent.py in observe(self, terminal, reward)
224 self.episode = self.model.observe(
225 terminal=self.observe_terminal,
--> 226 reward=self.observe_reward
227 )
228 self.observe_terminal = list()
~/environments/vip_protection/tensorforce/tensorforce/models/model.py in observe(self, terminal, reward)
1240
1241 self.is_observe = True
-> 1242 episode = self.monitored_session.run(fetches=fetches, feed_dict=feed_dict)
1243 self.is_observe = False
1244
/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py in run(self, fetches, feed_dict, options, run_metadata)
519 feed_dict=feed_dict,
520 options=options,
--> 521 run_metadata=run_metadata)
522
523 def should_stop(self):
/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py in run(self, *args, **kwargs)
965 raise six.reraise(*original_exc_info)
966 else:
--> 967 raise six.reraise(*original_exc_info)
968
969
/usr/lib/python3/dist-packages/six.py in reraise(tp, value, tb)
684 if value.__traceback__ is not tb:
685 raise value.with_traceback(tb)
--> 686 raise value
687
688 else:
/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py in run(self, *args, **kwargs)
950 def run(self, *args, **kwargs):
951 try:
--> 952 return self._sess.run(*args, **kwargs)
953 except _PREEMPTION_ERRORS:
954 raise
/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py in run(self, fetches, feed_dict, options, run_metadata)
1022 feed_dict=feed_dict,
1023 options=options,
-> 1024 run_metadata=run_metadata)
1025
1026 for hook in self._hooks:
/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py in run(self, *args, **kwargs)
825
826 def run(self, *args, **kwargs):
--> 827 return self._sess.run(*args, **kwargs)
828
829
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
887 try:
888 result = self._run(None, fetches, feed_dict, options_ptr,
--> 889 run_metadata_ptr)
890 if run_metadata:
891 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1118 if final_fetches or final_targets or (handle and feed_dict_tensor):
1119 results = self._do_run(handle, final_targets, final_fetches,
-> 1120 feed_dict_tensor, options, run_metadata)
1121 else:
1122 results = []
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1315 if handle is None:
1316 return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1317 options, run_metadata)
1318 else:
1319 return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1334 except KeyError:
1335 pass
-> 1336 raise type(e)(node_def, op, message)
1337
1338 def _extend_graph(self):
InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:], got updates.shape [2,4], indices.shape [1], params.shape [5000,4]
[[Node: ppo/observe-timestep/store/ScatterUpdate = ScatterUpdate[T=DT_FLOAT, Tindices=DT_INT32, _class=["loc:/initialize/latest/initialize/state-state"], use_locking=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](ppo/initialize/latest/initialize/state-state, ppo/observe-timestep/store/mod/_367, ppo/strided_slice, ^ppo/observe-timestep/store/AssignSub/_371)]]
[[Node: ppo/observe-timestep/cond/optimization/multi-step/step/subsampling-step/step/adam/step/add_7/_512 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_2223_ppo/observe-timestep/cond/optimization/multi-step/step/subsampling-step/step/adam/step/add_7", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
Caused by op 'ppo/observe-timestep/store/ScatterUpdate', defined at:
File "/usr/local/bin/ipython", line 11, in <module>
sys.exit(start_ipython())
File "/usr/local/lib/python3.5/dist-packages/IPython/__init__.py", line 125, in start_ipython
return launch_new_instance(argv=argv, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/traitlets/config/application.py", line 658, in launch_instance
app.start()
File "/usr/local/lib/python3.5/dist-packages/IPython/terminal/ipapp.py", line 356, in start
self.shell.mainloop()
File "/usr/local/lib/python3.5/dist-packages/IPython/terminal/interactiveshell.py", line 480, in mainloop
self.interact()
File "/usr/local/lib/python3.5/dist-packages/IPython/terminal/interactiveshell.py", line 471, in interact
self.run_cell(code, store_history=True)
File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2856, in run_ast_nodes
if self.run_code(code, result):
File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2910, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-1-2d2d706ed1b7>", line 1, in <module>
get_ipython().run_line_magic('run', 'vip_protection_rl')
File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2095, in run_line_magic
result = fn(*args,**kwargs)
File "<decorator-gen-60>", line 2, in run
File "/usr/local/lib/python3.5/dist-packages/IPython/core/magic.py", line 187, in <lambda>
call = lambda f, *a, **k: f(*a, **k)
File "/usr/local/lib/python3.5/dist-packages/IPython/core/magics/execution.py", line 775, in run
run()
File "/usr/local/lib/python3.5/dist-packages/IPython/core/magics/execution.py", line 761, in run
exit_ignore=exit_ignore)
File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2491, in safe_execfile
self.compile if shell_futures else None)
File "/usr/local/lib/python3.5/dist-packages/IPython/utils/py3compat.py", line 186, in execfile
exec(compiler(f.read(), fname, 'exec'), glob, loc)
File "/home/hassam/environments/vip_protection/main_code/vip_protection_rl.py", line 197, in <module>
train(arglist)
File "/home/hassam/environments/vip_protection/main_code/vip_protection_rl.py", line 110, in train
trainers = get_trainers_ppo(obs_shape_n, action_shape_n, env.n)
File "/home/hassam/environments/vip_protection/main_code/vip_protection_rl.py", line 73, in get_trainers_ppo
agent = PPOTrainer(observation_space_dimension[i], action_space_dimension[i])
File "/home/hassam/environments/vip_protection/main_code/rl_models/agents.py", line 9, in __init__
self.agent = Agent.from_spec(spec=agent, kwargs=dict(states=OpenAIGym.state_from_space(state_dim), actions=OpenAIGym.action_from_space(action_dim), network=network))
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/agents/agent.py", line 291, in from_spec
kwargs=kwargs
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/util.py", line 159, in get_object
return obj(*args, **kwargs)
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/agents/ppo_agent.py", line 151, in __init__
entropy_regularization=entropy_regularization
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/agents/learning_agent.py", line 149, in __init__
batching_capacity=batching_capacity
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/agents/agent.py", line 79, in __init__
self.model = self.initialize_model()
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/agents/ppo_agent.py", line 179, in initialize_model
likelihood_ratio_clipping=self.likelihood_ratio_clipping
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/pg_prob_ratio_model.py", line 88, in __init__
gae_lambda=gae_lambda
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/pg_model.py", line 95, in __init__
requires_deterministic=False
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/distribution_model.py", line 86, in __init__
discount=discount
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/memory_model.py", line 106, in __init__
reward_preprocessing=reward_preprocessing
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/model.py", line 200, in __init__
self.setup()
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/model.py", line 344, in setup
independent=independent
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/memory_model.py", line 583, in create_operations
independent=independent
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/model.py", line 1006, in create_operations
self.create_observe_operations(reward=reward, terminal=terminal)
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/model.py", line 983, in create_observe_operations
reward=reward
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/template.py", line 278, in __call__
result = self._call_func(args, kwargs, check_for_new_variables=False)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/template.py", line 217, in _call_func
result = self._func(*args, **kwargs)
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/models/memory_model.py", line 481, in tf_observe_timestep
reward=reward
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/template.py", line 261, in __call__
return self._call_func(args, kwargs, check_for_new_variables=True)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/template.py", line 217, in _call_func
result = self._func(*args, **kwargs)
File "/home/hassam/environments/vip_protection/tensorforce/tensorforce/core/memories/queue.py", line 169, in tf_store
updates=state
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 786, in scatter_update
use_locking=use_locking, name=name)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): Must have updates.shape = indices.shape + params.shape[1:], got updates.shape [2,4], indices.shape [1], params.shape [5000,4]
[[Node: ppo/observe-timestep/store/ScatterUpdate = ScatterUpdate[T=DT_FLOAT, Tindices=DT_INT32, _class=["loc:/initialize/latest/initialize/state-state"], use_locking=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](ppo/initialize/latest/initialize/state-state, ppo/observe-timestep/store/mod/_367, ppo/strided_slice, ^ppo/observe-timestep/store/AssignSub/_371)]]
[[Node: ppo/observe-timestep/cond/optimization/multi-step/step/subsampling-step/step/adam/step/add_7/_512 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_2223_ppo/observe-timestep/cond/optimization/multi-step/step/subsampling-step/step/adam/step/add_7", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
</module></lambda></decorator-gen-60></module></ipython-input-1-2d2d706ed1b7></module></module>
该提问来源于开源项目:tensorforce/tensorforce
- 点赞
- 回答
- 收藏
- 复制链接分享
8条回答
为你推荐
- 在基于golang的静态API应用程序方面需要帮助
- python
- api
- 1个回答
- Opencart将奖励积分添加到客户的订单电子邮件中
- php
- 1个回答
- 亚马逊奖励网站的会员代码
- php
- 1个回答
- 使用数据库登录时重定向到用户特定URL,并在用户进行时更新此URL
- database
- php
- redirect
- 1个回答
- 传递.php文件作为参数
- php
- 3个回答
换一换