liudinglldd 2022-05-11 17:37 采纳率: 22.2%
浏览 34
已结题

seqGAN代码中target_lstm计算self.pretrain_loss是什么含义?

问题遇到的现象和发生背景

尝试运行seqGAN源代码,
论文:https://arxiv.org/abs/1609.05473
代码:https://github.com/LantaoYu/SeqGAN
在查看代码时发现,target_lstm.py中计算self.pretrain_loss有两个量参与,一个是由generator生成的eval_file通过likelihood_data_loader得到的batch样本,传递给self.x,另一个参与的变量是通过oracl model计算得到的self.g_predictions,用自己的数据文件来替换掉oracl model来产生训练数据,那么在计算这个self.pretrain_loss时,如何不使用oracl model,而是用自己的数据文件来参与呢?

问题相关代码,请勿粘贴截图

计算self.pretrain_loss的代码如下

self.pretrain_loss = -tf.reduce_sum(
            tf.one_hot(tf.cast(tf.reshape(self.x, [-1]), tf.int32), self.num_emb, 1.0, 0.0) *
            tf.math.log(tf.reshape(self.g_predictions, [-1, self.num_emb]))
        ) / (self.sequence_length * self.batch_size)

计算self.g_predictions的代码如下

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions): 
            h_t = self.g_recurrent_unit(x_t, h_tm1)           # 两个参数构成tuple传到函数中,(64,32)(2,64,32)
            o_t = self.g_output_unit(h_t)                            # o_t, shape=(64, 5000)
            g_predictions = g_predictions.write(i, o_t)       # shape=(64,5000)
            x_tp1 = ta_emb_x.read(i)                                # shape=(64,32)
            return i + 1, x_tp1, h_t, g_predictions

        it, _, _, self.g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_pretrain_recurrence, 
            loop_vars=(
                tf.constant(0, dtype=tf.int32),
                tf.nn.embedding_lookup(self.g_embeddings, self.start_token),  # (5000,32) (64,) = (64,32)
                self.h0,
                g_predictions) 
        )

        self.g_predictions = tf.transpose(self.g_predictions.stack(), perm=[1, 0, 2])

代码中的self.x是传入的shape=(64,20)的一个batch的sample,由generator.py生成eval_file,(9984,20)的数据量,再通过likelihood_data_loader对eval_file中的sample进行处理和shape变换得到的,处理后的总数据为(156,64,20),这里只传入一个batch,为(64,20)。

我想要达到的结果

self.g_predictions通过构造的oracl model的网络模型计算得到,因为想使用自己的数据文件,而不通过这个模型运算输出数据,那么再计算这个self.pretrain_loss时应该怎么使用自己的数据文件?
为了方便处理,将自己的数据文件格式直接改为(n,20)的数据量,n暂不给定。

  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 5月19日
    • 创建了问题 5月11日

    悬赏问题

    • ¥15 机电一体化系统设计说明书
    • ¥20 sgy数据提取地震波速,有人能回答吗小馋
    • ¥20 c#实现打开word的功能,并且需要安装成windows服务,word打不开怎么办
    • ¥15 python用ARIMA时间预测模型预测数据出错,急!
    • ¥30 为什么后端传给前端vue的河流json数据不在地图中显示出来
    • ¥50 关于弹性波动方程求解的问题:
    • ¥100 C# 做 10个串口数据采集;采集数据填到 EXE表中;1路与PLC MODBUS通讯 ;要求速度尽量快点; 有能做完整程序的单聊,重酬
    • ¥15 P0口接8个LED,编写程序,并绘制proteus仿真电路原理图
    • ¥15 java,maven
    • ¥15 单独编译安卓13车载evs