m0_56499137 2021-05-07 17:59 采纳率: 0%
浏览 18

inport_data 和 model是自己写的py文件吗?他们是什么作用?

import input_data
import model

# 变量声明
N_CLASSES = 4  # 四种花类型
IMG_W = 64  # resize图像,太大的话训练时间久
IMG_H = 64
BATCH_SIZE = 20
CAPACITY = 200
MAX_STEP = 2000  # 一般大于10K
learning_rate = 0.0001  # 一般小于0.0001

# 获取批次batch
train_dir = 'D:/桌面/CDA/7、机器学习/input_data'  # 训练样本的读入路径
logs_train_dir = 'D:/桌面/CDA/7、机器学习/save'  # logs存储路径

# train, train_label = input_data.get_files(train_dir)
train, train_label, val, val_label = input_data.get_files(train_dir, 0.3)
# 训练数据及标签
train_batch, train_label_batch = input_data.get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
# 测试数据及标签
val_batch, val_label_batch = input_data.get_batch(val, val_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

# 训练操作定义
train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
train_loss = model.losses(train_logits, train_label_batch)
train_op = model.trainning(train_loss, learning_rate)
train_acc = model.evaluation(train_logits, train_label_batch)

# 测试操作定义
test_logits = model.inference(val_batch, BATCH_SIZE, N_CLASSES)
test_loss = model.losses(test_logits, val_label_batch)
test_acc = model.evaluation(test_logits, val_label_batch)

# 这个是log汇总记录
summary_op = tf.summary.merge_all()

# 产生一个会话
sess = tf.Session()
# 产生一个writer来写log文件
train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
# val_writer = tf.summary.FileWriter(logs_test_dir, sess.graph)
# 产生一个saver来存储训练好的模型
saver = tf.train.Saver()
# 所有节点初始化
sess.run(tf.global_variables_initializer())
# 队列监控
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# 进行batch的训练
try:
    # 执行MAX_STEP步的训练,一步一个batch
    for step in np.arange(MAX_STEP):
        if coord.should_stop():
            break
        _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])

        # 每隔50步打印一次当前的loss以及acc,同时记录log,写入writer
        if step % 10 == 0:
            print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
            summary_str = sess.run(summary_op)
            train_writer.add_summary(summary_str, step)
        # 每隔100步,保存一次训练好的模型
        if (step + 1) == MAX_STEP:
            checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')

finally:
    coord.request_stop()

  • 写回答

1条回答 默认 最新

  • 睿妈陪娃 2024-07-11 12:22
    关注

    不是

    评论

报告相同问题?

悬赏问题

  • ¥15 咨询一下有关于王者荣耀赢藏战绩
  • ¥100 求购一套带接口实现实习自动签到打卡
  • ¥50 MacOS 使用虚拟机安装k8s
  • ¥500 亚马逊 COOKIE我如何才能实现 登录一个亚马逊账户 下发新 COOKIE ..我使用下发新COOKIE 导入ADS 指纹浏览器登录,我把账户密码 修改过后,原来下发新COOKIE 不会失效的方式
  • ¥20 玩游戏gpu和cpu利用率特别低,玩游戏卡顿
  • ¥25 oracle中的正则匹配
  • ¥15 关于#vscode#的问题:把软件卸载不会再出现蓝屏
  • ¥15 vimplus出现的错误
  • ¥15 usb无线网卡转typec口
  • ¥30 怎么使用AVL fire ESE软件自带的优化模式来优化设计Soot和NOx?