问题遇到的现象和发生背景
问题相关代码,请勿粘贴截图
def dqn(n_episodes=EPISODE_COUNT, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
scores = []
scores_window = deque(maxlen=100)
eps = eps_start
for i_episode in range(1, n_episodes+1):
print("Episode" + str(i_episode))
state = getState(stockData, 0, STATE_SIZE + 1)
pos_old = 0 #初始化持仓
money_initial = 10000 # 初始化资金
money = money_initial
cost = 0 # 初始化等效成本
# total_profit = 0
total_share = 0 # 初始化持股
agent.inventory = []
reward = 0
for t in range(l):
action = agent.act(state, eps)
next_state = getState(stockData, t + 1, STATE_SIZE + 1)
# reward = 0
if action == 1 :# 加仓20%
#agent.inventory.append(stockData[t])
#print("buy" + str(stockData[t]))
pos_new = min(pos_old + 0.2, 1)
total_share += money * (pos_new - pos_old) / stockData[t]
elif action == 2:
# 减仓20%
#bought_price = agent.inventory.pop(0)
pos_new = max(pos_old - 0.2, 0)
total_share += money * (pos_new - pos_old) / stockData[t]
# reward = max(stockData[t] - bought_price, 0)
# reward = stockData[t] - cost
# print("Sell: " + str(stockData[t]) + " | Profit: " + str(stockData[t] - bought_price))
else: # 持仓
pos_new = pos_old
# cost = cost_calculate(stockData[t], money, pos_new, total_share)
money = money_calculate(money, total_share, stockData[t], pos_new)
if money < 0 or t == l - 1:
done = 1
else:
done = 0
reward = (money - money_initial) / money_initial
agent.step(state, action, reward, next_state, done)
eps = max(eps_end, eps * eps_decay)
state = next_state
pos_old = pos_new
if done:
print("------------------------------")
print("total_profit = " + str((money - money_initial) / money_initial))
print("------------------------------")
break
scores.append((money - money_initial) / money_initial)
scores_window.append((money - money_initial) / money_initial)
if np.mean(scores_window) > 0.2 and len(scores_window) == 100:
torch.save(agent.actor_local.state_dict(), 'checkpoint_actor.pth')
torch.save(agent.critic_local.state_dict(), 'checkpoint_critic.pth')
break
torch.save(agent.actor_local.state_dict(), 'checkpoint_actor.pth')
torch.save(agent.critic_local.state_dict(), 'checkpoint_critic.pth')
return scores
运行结果及报错内容
Episode1
C:\Users\22536.conda\envs\torch\lib\site-packages\torch\nn\functional.py:1795: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")
Traceback (most recent call last):
File "D:/Git/stockPrediction2/main_new1.py", line 124, in
scores = dqn()
File "D:/Git/stockPrediction2/main_new1.py", line 45, in dqn
if action == 1 :# 加仓20%
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Process finished with exit code 1