您好,我使用
model.conv1.params()
返回的是一个link,然后用
model.conv1.copyparams()
就报错缺少参数 link
于是我用:
model.conv1.copyparams(model.conv1.params() )
也行不通,
请问各位前辈后辈这个问题怎么解决?
万分感谢
您好,我使用
model.conv1.params()
返回的是一个link,然后用
model.conv1.copyparams()
就报错缺少参数 link
于是我用:
model.conv1.copyparams(model.conv1.params() )
也行不通,
请问各位前辈后辈这个问题怎么解决?
万分感谢
首先明确每一个基本的网络连接层(Network Connection)都只含有参数 W 和 b,数据类型是 Variable.
从当前模块不断定位到最初的Network Connection之后,加上".W.data"或者".b.data"就可以了。
譬如说我的网络是: VGG,每一层都是一个自定义模块"ConvBlock",每一个Block中有“self.c1=F.convolution_2d...”和其他的基本网络块组成的,那么就只需要:
vgg = VGG()
param_of_conv_1_1 = vgg.conv1.c1.W.data
# 输出是一个np.asarray数组