D222097 2020-05-07 00:42 采纳率: 40%
浏览 1025
已结题

tensorflow2.0如何实现参数共享

您好,我想请教一下如何在tf2.0中实现参数共享。
我正在尝试实现【Unsupervised Visual Representation Learning by Context Prediction】这篇论文中的模型结构如下图所示:

这个结构的前几层是由相同的两部分组成的,然后在fc7融合为一层,但是这个模型要求fc7层以前并行的两部分的参数是共享的。我不知道该如何在tf2.0里面实现参数共享。要在tf2.0中实现这样的结构该如何做呢?

我尝试的代码如下:

def Alex_net(x):
    conv1    = tf.keras.layers.Conv2D(96,(11,11),activation='relu',strides=(4,4))(x)
    maxpool1 = tf.keras.layers.MaxPooling2D((3,3),strides=(2,2))(conv1)
    conv2    = tf.keras.layers.Conv2D(256,(5,5),activation='relu',padding='same')(maxpool1)
    maxpool2 = tf.keras.layers.MaxPooling2D((3,3),strides=(2,2))(conv2)
    conv3    = tf.keras.layers.Conv2D(384,(3,3),activation='relu',padding='same')(maxpool2)
    conv4    = tf.keras.layers.Conv2D(384,(3,3),activation='relu',padding='same')(conv3)
    conv5    = tf.keras.layers.Conv2D(256,(3,3),activation='relu',padding='same')(conv4)
    maxpool5 = tf.keras.layers.MaxPooling2D((3,3),strides=(2,2))(conv5)
    fc6      = tf.keras.layers.Dense(4096,activation='relu')(maxpool5)
    print(fc6)
    return fc6

def Concat_net(x1,x2):
    input_1 = Alex_net(x1)
    input_2 = Alex_net(x2)   
    concat  = tf.keras.layers.concatenate([input_1,input_2])
    fc7     = tf.keras.layers.Dense(4096,activation='relu')(concat)
    fc8     = tf.keras.layers.Dense(4096,activation='relu')(fc7)
    fc9     = tf.keras.layers.Dense(8,activation='softmax')(fc8)
    C       = fc9
    return C

def final_net(width,height,depth):
    inputshape=(height,width,depth)
    inputs_1  = tf.keras.layers.Input(shape=inputshape)
    inputs_2  = tf.keras.layers.Input(shape=inputshape)
    outputs   = Concat_net(inputs_1, inputs_2)
    model     = tf.keras.Model([inputs_1,inputs_2],outputs,name='concat_NET')
    return model
F=final_net(96,96,3)
F.summary()

但是这样summary打印出来的参数是独立的,并不是共享的,模型只是重复利用了结构。summary的结果如下:

图片说明

要在tf2.0中实现这样的结构该如何做呢?

  • 写回答

3条回答 默认 最新

  • 滑动窗口协议 2020-05-08 23:49
    关注
    import tensorflow as tf
    from tensorflow import keras
    
    def Concat_net(x1,x2,model):
        input_1 = model.predict(x1)
        input_2 = model.predict(x2)   
        concat  = tf.keras.layers.concatenate([input_1,input_2])
        fc7     = tf.keras.layers.Dense(4096,activation='relu')(concat)
        fc8     = tf.keras.layers.Dense(4096,activation='relu')(fc7)
        fc9     = tf.keras.layers.Dense(8,activation='softmax')(fc8)
        C       = fc9
        return C
    
    def final_net(inputshape,model):
        inputs_1  = tf.keras.layers.Input(shape=inputshape)
        inputs_2  = tf.keras.layers.Input(shape=inputshape)
        outputs   = Concat_net(inputs_1, inputs_2, model)
        model     = tf.keras.Model([inputs_1,inputs_2],outputs,name='concat_NET')
        return model
    
    Alex_net = keras.Sequential([
        keras.layers.Conv2D(96,(11,11),activation='relu',strides=(4,4)),
        keras.layers.MaxPooling2D((3,3),strides=(2,2)),
        keras.layers.Conv2D(256,(5,5),activation='relu',padding='same'),
        keras.layers.MaxPooling2D((3,3),strides=(2,2)),
        keras.layers.Conv2D(384,(3,3),activation='relu',padding='same'),
        keras.layers.Conv2D(384,(3,3),activation='relu',padding='same'),
        keras.layers.Conv2D(256,(3,3),activation='relu',padding='same'),
        keras.layers.MaxPooling2D((3,3),strides=(2,2)),
        keras.layers.Dense(4096,activation='relu')
    ])
    inputshape=(96,96,3)
    F=final_net(inputshape,Alex_net)
    F.summary()
    
    
    评论

报告相同问题?

悬赏问题

  • ¥30 这是哪个作者做的宝宝起名网站
  • ¥60 版本过低apk如何修改可以兼容新的安卓系统
  • ¥25 由IPR导致的DRIVER_POWER_STATE_FAILURE蓝屏
  • ¥50 有数据,怎么建立模型求影响全要素生产率的因素
  • ¥50 有数据,怎么用matlab求全要素生产率
  • ¥15 TI的insta-spin例程
  • ¥15 完成下列问题完成下列问题
  • ¥15 C#算法问题, 不知道怎么处理这个数据的转换
  • ¥15 YoloV5 第三方库的版本对照问题
  • ¥15 请完成下列相关问题!