Lin_zhicheng 2021-09-06 09:16 采纳率: 76.9%
浏览 23
已结题

全卷积网络结构问题[Tensorflow2.0]

img


代码如下文,
我的问题是这样,我用的tf搭建了一个基于VGG16的简单FCN模型,运行的也很成功,但是我用plot_model对网络结构进行可视化的时候,却发现,展示出来的图像是惯序的,缺少了迁跃的结构,但是我用model.summary()得到的结果却显示迁跃结构是存在的,所以我不知道是哪里出了问题


import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

# 载入预训练网络 VGG16 使用imagenet数据集
conv_base = tf.keras.applications.VGG16(weights='imagenet',
                                        input_shape=(224,224,3),
                                        include_top=False)



# 创建一个多输出的模型,一劳永逸
# 通过调用子model产生的多输出,来获得原来模型不同层的输出
layer_names = ['block5_pool' ,
               'block5_conv3',
               'block4_conv3',
               'block3_conv3',
                          
]

layers_output = [conv_base.get_layer(layer_name).output for layer_name in layer_names]

# 搭建多输出子模型
multi_out_model = tf.keras.models.Model(inputs = conv_base.input,
                                        outputs = layers_output
                                        )

multi_out_model.trainable = False # 冻结住预训练参数


# 先定义一个输入
inputs = tf.keras.layers.Input(shape=(224,224,3))
# 送入模型,获得输出
out, out_block5_conv3, out_block4_conv3, out_block3_conv3= multi_out_model(inputs)
print(out_block5_conv3.shape,out_block4_conv3.shape,out_block3_conv3.shape,out.shape)

# 开始上采样,一层一层来
# 先将out层7,7,512,上采样成14,14,512
x1 = tf.keras.layers.Conv2DTranspose(512,3,
                                     strides=2,
                                     padding='same',
                                     activation='relu')(out)
x1.shape # [14,14,512]
# 再加一层卷积(不改变大小)来提取x1的特征
x1 = tf.keras.layers.Conv2D(512,3,
                            strides=1,
                            padding='same',
                            activation='relu')(x1)

# 然后与out_block5_conv3[14,14,512]相加
# x2 = tf.add(x1, out_block5_conv3) # [14,14,512] 元素对应相加
# conv_base.get_layer('block5_conv3').output
x2 = x1 + out_block5_conv3

# 然后上采样成 28,28,512
x2 = tf.keras.layers.Conv2DTranspose(512,3,
                                     strides=2,
                                     padding='same',
                                     activation='relu')(x2)
x2.shape # [28,28,512]
# 再加一层卷积(不改变大小)来提取x2的特征
x2 = tf.keras.layers.Conv2D(512,3,
                            strides=1,
                            padding='same',
                            activation='relu')(x2)

x3 = tf.add(x2, out_block4_conv3) # [28,28,512]
# 上采样,成 56,56,256 和 out_block3_conv3相匹配
x3 = tf.keras.layers.Conv2DTranspose(256,3,
                                     strides=2,
                                     padding='same',
                                     activation='relu')(x3)
x3.shape # [56,56,256]
# 再加一层卷积(不改变大小)来提取x3的特征
x3 = tf.keras.layers.Conv2D(256,3,
                            strides=1,
                            padding='same',
                            activation='relu')(x3)

x4 = tf.add(x3, out_block3_conv3) # [56,56,256]

x4.shape # [56,56,256]

# 对x4进行上采样,使得输出通道为3
x5 = tf.keras.layers.Conv2DTranspose(128,3,
                                     strides=2,
                                     padding='same',
                                     activation='relu')(x4)
x5.shape # [112,112,128]
# 再加一层卷积(不改变大小)来提取特征
x5 = tf.keras.layers.Conv2D(128,3,strides=1,padding='same',activation='relu')(x5)

# 继续上采样成 224,224,3
# 因为是多分类输出所以激活函数用softmax
prediction = tf.keras.layers.Conv2DTranspose(3,3,
                                     strides=2,
                                     padding='same',
                                     activation='softmax')(x5)
prediction.shape # [224,224,3] 恢复成原来图片大小了

# model 的创建
model = tf.keras.models.Model(
    inputs=inputs,
    outputs=prediction
)

# 可视化模型结构
#导入下面的库
from tensorflow.keras.utils import plot_model  
import pydotplus  
#参数 :模型名称,结构图保存位置,是否展示shape
plot_model(model,to_file='./model.png',show_shapes=True)

  • 写回答

1条回答 默认 最新

  • 「已注销」 2021-09-06 09:25
    关注

    img

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 系统已结题 9月20日
  • 已采纳回答 9月12日
  • 创建了问题 9月6日

悬赏问题

  • ¥15 VB.NET2022如何生成发布成exe文件
  • ¥30 matlab appdesigner私有函数嵌套整合
  • ¥15 给我一个openharmony跑通webrtc实现视频会议的简单demo项目,sdk为12
  • ¥15 vb6.0使用jmail接收smtp邮件并另存附件到D盘
  • ¥30 vb net 使用 sendMessage 如何输入鼠标坐标
  • ¥15 关于freesurfer使用freeview可视化的问题
  • ¥100 谁能在荣耀自带系统MagicOS版本下,隐藏手机桌面图标?
  • ¥15 求SC-LIWC词典!
  • ¥20 有关esp8266连接阿里云
  • ¥15 C# 调用Bartender打印机打印