Mapless丶 2019-08-02 22:30 采纳率: 0%
浏览 1060

我的keras的model.fit写在一个loop里,callback每一个epoch会生成一个events文件,如何处理这种问题?

# Build the policy network.
# The original code duplicated the whole model construction in both the
# `resume` and non-`resume` branches; the architecture is identical in both,
# so it is built once and only the weight restoration is conditional.

# creates a generic neural network architecture
model = Sequential()

# hidden layer takes a pre-processed 80x80 frame (flattened) as input,
# and has 200 units
model.add(Dense(units=200, input_dim=80 * 80, activation='relu',
                kernel_initializer='glorot_uniform'))

# output layer: a single sigmoid unit (probability of one of the two actions)
model.add(Dense(units=1, activation='sigmoid',
                kernel_initializer='RandomNormal'))

# compile the model using traditional Machine Learning losses and optimizers
model.compile(loss='binary_crossentropy', optimizer='adam',
              metrics=['accuracy'])

# print model
model.summary()

if resume and os.path.isfile('Basic_Rl_weights.h5'):
    # load pre-trained model weights from a previous run
    print("loading previous weights")
    model.load_weights('Basic_Rl_weights.h5')

# save model
# model.save_weights('my_model_weights.h5')

# One timestamped log directory per run, so separate runs do not overwrite
# each other's TensorBoard data.
log_dir = './log' + datetime.now().strftime("%Y%m%d-%H%M%S") + "/"

# NOTE(review): `model.fit()` below is called once per episode, and the
# Keras TensorBoard callback opens a fresh event writer every time training
# (re)starts — that is why a new events file appears for every "epoch".
# To get a single events file per run, log scalars manually with one
# summary writer (as `tflog` does for `running_reward`) instead of relying
# on this callback inside the loop.
tensorboard = callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0,
                                    write_graph=True, write_images=True)

# Keep the original name bound as well: the training loop passes
# `callbacks=[callbacks]` to model.fit(). (This rebinding shadows the Keras
# `callbacks` module — avoid using `callbacks.X` as a module after this point.)
callbacks = tensorboard

# --- gym initialization ---
# (these section titles were bare prose lines in the scraped source, which
# would be syntax errors; they are comments here)
env = gym.make("Pong-v0")
observation = env.reset()
prev_x = None  # previous frame, used to compute the difference frame
running_reward = None  # exponentially smoothed episode reward (for logging)

# --- initialization of variables used in the main loop ---
x_train, y_train, rewards = [], [], []  # per-episode inputs, labels, rewards
reward_sum = 0
episode_number = 0

# --- main loop ---
# One iteration per environment step; at episode end, train the policy
# network on the whole episode with discounted-reward sample weights.
while True:
    if render:
        env.render()

    # preprocess the observation, set input as difference between frames
    cur_x = prepro(observation)
    if prev_x is not None:
        x = cur_x - prev_x
    else:
        # first frame of an episode: no previous frame to diff against
        x = np.zeros(Input_dim)
    prev_x = cur_x

    # forward the policy network and sample an action according to the
    # returned probability; the 1-D frame is reshaped to a (1, D) batch
    prob = model.predict(np.expand_dims(x, axis=1).T)

    if np.random.uniform() < prob:
        action = action_up
    else:
        action = action_down

    # 0/1 label (a fake label so the policy-gradient update can be driven
    # through the binary cross-entropy loss)
    y = 1 if action == 2 else 0

    # log the input and label to train on later
    x_train.append(x)
    y_train.append(y)

    # do one step in our environment
    observation, reward, done, info = env.step(action)
    rewards.append(reward)
    reward_sum += reward

    # end of an episode
    if done:
        print('At the end of episode', episode_number, 'the total reward was :', reward_sum)

        # increment episode number
        episode_number += 1

        # training: each sample is weighted by its discounted return
        model.fit(x=np.vstack(x_train),
                  y=np.vstack(y_train),
                  verbose=1,
                  sample_weight=discount_rewards(rewards),
                  callbacks=[callbacks])

        # BUG FIX: weights were previously saved under a timestamped name
        # ('Basic_Rl_weights<timestamp>.h5'), but the resume branch at
        # start-up only looks for 'Basic_Rl_weights.h5', so saved weights
        # could never be reloaded. Save under the fixed name the loader uses.
        if episode_number % 100 == 0:
            model.save_weights('Basic_Rl_weights.h5')

        # Log the reward (exponentially smoothed over episodes)
        running_reward = reward_sum if running_reward is None else running_reward * 0.99 + reward_sum * 0.01
        tflog('running_reward', running_reward, custom_dir=log_dir)

        # Reinitialization for the next episode
        x_train, y_train, rewards = [], [], []
        observation = env.reset()
        reward_sum = 0
        prev_x = None

  • 写回答

1条回答 默认 最新

  • 关注
    评论

报告相同问题?

悬赏问题

  • ¥60 版本过低apk如何修改可以兼容新的安卓系统
  • ¥25 由IPR导致的DRIVER_POWER_STATE_FAILURE蓝屏
  • ¥50 有数据,怎么建立模型求影响全要素生产率的因素
  • ¥50 有数据,怎么用matlab求全要素生产率
  • ¥15 TI的insta-spin例程
  • ¥15 完成下列问题
  • ¥15 C#算法问题, 不知道怎么处理这个数据的转换
  • ¥15 YoloV5 第三方库的版本对照问题
  • ¥15 请完成下列相关问题!
  • ¥15 drone 推送镜像时候 purge: true 推送完毕后没有删除对应的镜像,手动拷贝到服务器执行结果正确在样才能让指令自动执行成功删除对应镜像,如何解决?