Mapless丶 2019-08-02 22:30 采纳率: 0%
浏览 1060

我的 Keras 的 model.fit 写在一个循环里,每次调用 fit(每个 epoch)TensorBoard callback 都会生成一个新的 events 文件,该如何处理这种问题?

# Build the policy network. The architecture is the same whether or not a
# previous run is being resumed, so it is built once; only the weight loading
# depends on `resume`.
model = Sequential()

# hidden layer takes a pre-processed frame as input, and has 200 units
model.add(Dense(units=200, input_dim=80*80, activation='relu', kernel_initializer='glorot_uniform'))

# output layer
model.add(Dense(units=1, activation='sigmoid', kernel_initializer='RandomNormal'))

# compile the model using traditional Machine Learning losses and optimizers
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# print model
model.summary()

# When resuming, restore the pre-trained weights if the checkpoint exists;
# otherwise training simply starts from the fresh initialisation above.
if resume and os.path.isfile('Basic_Rl_weights.h5'):
    print("loading previous weights")
    model.load_weights('Basic_Rl_weights.h5')

#save model
# model.save_weights('my_model_weights.h5')

# One log directory per launch of the script, named after the start time.
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = './log' + run_stamp + "/"

# NOTE(review): this rebinds the name `callbacks` from the imported module to
# a single TensorBoard instance, so the module is no longer reachable under
# that name after this line. The training loop passes this object to fit().
callbacks = callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=True,
    write_images=True,
)

# gym initialization

env = gym.make("Pong-v0")
observation = env.reset()
prev_x = None  # used in computing the difference frame
running_reward = None  # smoothed episode reward; set after the first episode

# initialization of variables used in the main loop

# per-episode buffers: network inputs, fake labels and per-step rewards
x_train, y_train, rewards = [], [], []
reward_sum = 0
episode_number = 0

# main loop
#
# Plays the environment step by step, recording (input, label, reward) for
# every step. When an episode finishes, the network is fitted once on the
# recorded steps, weighted by discount_rewards(rewards), and the buffers are
# reset for the next episode. The loop never terminates on its own.

while True:
    if render:
        env.render()

    # preprocess the observation, set input as difference between images
    cur_x = prepro(observation)
    if prev_x is not None:
        x = cur_x - prev_x
    else:
        # first frame of an episode: there is no previous frame to diff with
        x = np.zeros(Input_dim)
    prev_x = cur_x

    # forward the policy network and sample action according to the
    # probability it returns (x is fed as a single-row batch)
    prob = model.predict(np.expand_dims(x, axis=1).T)

    if np.random.uniform() < prob:
        action = action_up
    else:
        action = action_down

    # 0 and 1 labels (a fake label in order to achieve back propagation).
    # Compare against action_up, the same name used when sampling above,
    # instead of a hard-coded action id.
    y = 1 if action == action_up else 0

    # log the input and label to train later
    x_train.append(x)
    y_train.append(y)

    # do one step in our environment
    observation, reward, done, _ = env.step(action)
    rewards.append(reward)
    reward_sum += reward

    # end of an episode
    if done:
        print('At the end of episode', episode_number, 'the total reward was :', reward_sum)

        # increment episode number
        episode_number += 1

        # training
        # Each fit() call is a separate one-epoch run. With the defaults every
        # call would be reported to the callbacks as epoch 0, so the scalars
        # logged for all episodes would share the same step. Passing
        # initial_epoch/epochs trains exactly one epoch per episode and
        # reports it as epoch (episode_number - 1), so the step advances.
        # NOTE(review): depending on the Keras version the TensorBoard
        # callback may still open a new events file on each fit() call;
        # they all go to log_dir, which TensorBoard is expected to read as a
        # single run - confirm.
        model.fit(x=np.vstack(x_train),
                  y=np.vstack(y_train),
                  verbose=1,
                  sample_weight=discount_rewards(rewards),
                  initial_epoch=episode_number - 1,
                  epochs=episode_number,
                  callbacks=[callbacks])

        # NOTE(review): snapshots get a timestamped file name, whereas the
        # resume path earlier in this script looks for 'Basic_Rl_weights.h5'
        # - confirm how a snapshot is meant to be picked up on resume.
        if episode_number % 100 == 0:
            model.save_weights('Basic_Rl_weights' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.h5')

        # Log the reward (exponential moving average over episodes)
        running_reward = reward_sum if running_reward is None else running_reward * 0.99 + reward_sum * 0.01
        tflog('running_reward', running_reward, custom_dir=log_dir)

        # Reinitialization
        x_train, y_train, rewards = [], [], []
        observation = env.reset()
        reward_sum = 0
        prev_x = None

  • 写回答

1条回答

  • 关注
    评论

报告相同问题?

悬赏问题

  • ¥15 安卓adb backup备份应用数据失败
  • ¥15 eclipse运行项目时遇到的问题
  • ¥15 关于#c##的问题:最近需要用CAT工具Trados进行一些开发
  • ¥15 南大pa1 小游戏没有界面,并且报了如下错误,尝试过换显卡驱动,但是好像不行
  • ¥15 没有证书,nginx怎么反向代理到只能接受https的公网网站
  • ¥50 成都蓉城足球俱乐部小程序抢票
  • ¥15 yolov7训练自己的数据集
  • ¥15 esp8266与51单片机连接问题(标签-单片机|关键词-串口)(相关搜索:51单片机|单片机|测试代码)
  • ¥15 电力市场出清matlab yalmip kkt 双层优化问题
  • ¥30 ros小车路径规划实现不了,如何解决?(操作系统-ubuntu)