2301_81502608 2024-06-09 19:05 采纳率: 100%
浏览 4
已结题

tensorflow解CartPole遇到问题,求解答。

机器学习遇到了一些问题,下面是源代码。


```python
class DQLAgent:
    # hu -> hidden units, opt -> optimizer
    def __init__(self, gamma=0.95, hu=24, opt=keras.optimizers.Adam, lr=0.001, finish=False):
        self.finish = finish
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.gamma = gamma
        self.batch_size = 32
        self.max_treward = 0
        self.averages = list()
        self.memory = deque(maxlen=2000)
        self.osn = env.observation_space.shape[0]
        self.model = self._build_model(hu, opt, lr)
    
    def _build_model(self, hu, opt, lr):
        model = Sequential()
        model.add(Dense(hu, input_dim=self.osn, activation='relu'))
        model.add(Dense(hu, activation='relu'))
        model.add(Dense(env.action_space.n, activation='linear'))
        model.compile(loss='mse', optimizer=opt(learning_rate=lr))
        return model
    
    def act(self, state):
        if random.random() <= self.epsilon:
            return env.action_space.sample()
        action = self.model.predict(state)[0]
        return np.argmax(action)
    
    def replay(self):
        batch = random.sample(self.memory, self.batch_size)
        for state, action, reward, next_state, done in batch:
            if not done:
                reward += self.gamma * np.amax(self.model.predict(state)[0])
            target = self.model.predict(next_state)
            target[0, action] = reward
            self.model.fit(state, target, epochs=1, verbose=False)
        if self.epsilon > self.epsilon.min:
            self.epsilon *= self.epsilon_decay
            
    def learn(self, episodes):
        trewards = []
        for e in range(1, episodes + 1):
            state = env.reset()
            state = np.reshape(state, [1, self.osn])
            for _ in range(5000):
                action = self.act(state)
                next_state, reward, done, info = env.step(action)
                next_state = np.reshape(next_state, [1, self.osn])
                self.memory.append([state, action, reward, next_state, done])
                state = next_state
                if done:
                    treward = _ + 1
                    trewards.append(treward)
                    av = sum(trewards[-25:]) / 25
                    self.averages.append(av)
                    self.max_treward = max(self.max_treward, treward)
                    templ = 'episode: {:4d}/{} | treward: {:4d} | '
                    templ += 'av: {:6.1f} | max: {:4d}'
                    print(templ.format(e, episodes, treward, av, self.max_treward), end='\r')
                    break
            if av > 195 and self.finish:
                break
            if len(self.memory) > self.batch_size:
                self.replay()
            
    def test(self, episodes):
        trewards = []
        for e in range(1, episodes+1):
            state = env.reset()
            for _ in range(5001):
                state = np.reshape(state, [1, self.osn])
                action = np.argmax(self.model.predict(state)[0])
                next_state, reward, done, info = nev.step(action)
                state = next_state
                if done:
                    treward = _ + 1
                    trewards.append(treward)
                    print('episode: {:4d}/{} | treward: {:4d} | '
                          .format(e, episodes, treward, end='\r'))
                    break
        return rewards

报错 `tf.data.Dataset` only supports Python-style iteration in eager mode or within tf.function.
问题出在self.replay()里面,self.model.predict(state)出了问题,请问如何解决?
  • 写回答

2条回答 默认 最新

  • 阿里嘎多学长 2024-06-09 19:05
    关注

    以下内容由CHATGPT及阿里嘎多学长共同生成、有用望采纳:


    在TensorFlow 2.x中,tf.data.Dataset 只支持在 eager execution(即默认情况下)或在 tf.function 装饰的函数中进行 Python 风格的迭代。你遇到的问题可能是由于在 self.replay() 方法中对 self.model.predict(state) 的调用方式不当造成的。

    在TensorFlow 2.x中,predict 方法返回的是一个 NumPy 数组,可以直接迭代,不需要使用 Python 风格的迭代。但是,如果你的代码中有使用 tf.data.Dataset 的地方,确保它们在 tf.function 装饰的函数中使用。

    针对你的问题,这里有一些可能的解决方案:

    1. 确保predict方法的使用是正确的predict 方法应该直接返回预测结果,不需要迭代。你的代码中 self.model.predict(state)[0] 已经正确地获取了预测的第一个元素。

    2. 检查self.replay()方法中的循环:确保在循环中使用 predict 方法时,传入的参数是正确的。特别是 next_state 应该是一个已经reshape为 [1, self.osn] 的数组。

    3. 使用tf.function装饰器:如果你在其他地方使用了 tf.data.Dataset,并且需要迭代它,确保这些操作是在 tf.function 装饰的函数中进行的。

    4. 修改test方法中的env.step调用:在 test 方法中,你使用了 nev.step(action),这显然是一个拼写错误,应该是 env.step(action)

    5. 确保环境和模型的兼容性:确保你的环境和模型都是使用 TensorFlow 2.x 兼容的方式构建的。

    下面是一个修改后的 self.replay() 方法示例,它应该能够解决你的问题:

    def replay(self):
        minibatch = random.sample(self.memory, self.batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                target = (reward + self.gamma * np.amax(self.model.predict(next_state)[0]))
            
            target_f = self.model.predict(state)
            target_f[0][action] = target
            
            self.model.fit(state, target_f, epochs=1, verbose=0)
            
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
    

    请注意,这里使用 np.amax(self.model.predict(next_state)[0]) 来获取下一个状态的最大预测值,这是 DQN 算法中的一个常见步骤。

    至于参考资料,你可以查看 TensorFlow 官方文档,特别是关于 tf.functiontf.data.Dataset 的部分。这里有一个链接到 TensorFlow 2.x 的 tf.function 文档:tf.function

    希望这些信息能帮助你解决问题。如果还有其他问题或需要进一步的帮助,请随时告诉我。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 6月18日
  • 已采纳回答 6月10日
  • 创建了问题 6月9日