在论文所给的代码文件中,提示用test.py文件直接进行测试,但是在测试过程中设置好参数后出现了维度(貌似?)问题
__author__ = 'WEI'
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # BE QUIET!!!!
from mynetworks import load_trainable_vars
import numpy as np
import numpy.linalg as la
import tensorflow as tf
import myshrinkage
from scipy.io import savemat
from scipy.io import loadmat
from os import path
def _real_matmul_complex(matrix_, z_):
    """Multiply a real-valued matrix tensor by a complex tensor.

    The product is taken separately on the real and imaginary parts of
    ``z_`` and recombined with ``tf.complex``, so the real matrix is never
    cast to a complex dtype.
    """
    return tf.complex(tf.matmul(matrix_, tf.real(z_)),
                      tf.matmul(matrix_, tf.imag(z_)))


def main():
    """Evaluate a trained LAMP network on the stored test problems.

    Builds the T-layer graph, restores the trained variables from the
    ``.npz`` file, computes the NMSE (in dB) for every SNR of the selected
    range and writes those values into the matching slice of the results
    ``.mat`` file.

    Raises:
        ValueError: if ``shrink`` or ``snr_range`` is not a supported value.
    """
    shrink = 'soft'  # 'soft' or 'gm'
    snr_range = '10to20dB'  # '10to20dB' or '0to10dB'
    array_type = 'ULA'  # 'UPA'
    MN = '256128'

    # Name of the array stored in the results file for each shrinkage.
    result_keys = {'soft': 'LAMP_nmse', 'gm': 'GMLAMP_nmse'}
    # SNR values to test and the [begin, end) slice of the 5-entry result
    # vector that they fill.
    snr_settings = {
        '0to10dB': ([0, 5], 0, 2),
        '10to20dB': ([10, 15, 20], 2, 5),
    }
    # Compare by value (==/in), never with `is`, and fail early with a clear
    # message instead of a NameError further down.
    if shrink not in result_keys:
        raise ValueError('unsupported shrink: ' + repr(shrink))
    if snr_range not in snr_settings:
        raise ValueError('unsupported SNRrange: ' + repr(snr_range))
    result_key = result_keys[shrink]
    snr_values, ibegin, iend = snr_settings[snr_range]

    # Load the same compressive sensing matrix A.
    A = loadmat(array_type + 'CSmatrix' + MN)['A']

    # File holding the trained network and file collecting the results.
    trainedfilename = array_type + '_' + shrink + '_' + MN + snr_range + '.npz'
    saveresultsname = array_type + 'results' + '_' + shrink + '_' + MN + '.mat'
    if not path.exists(saveresultsname):
        # First run for this configuration: create an all-zero result vector.
        savemat(saveresultsname, {result_key: np.zeros(5, dtype=np.float64)})

    T = 8
    untied = 1
    eta, theta_init = myshrinkage.get_shrinkage_function(shrink)

    M, N = A.shape
    A_ = tf.constant(A, dtype=tf.float64)
    # Initial value of every B matrix: A^T scaled by 1.01 * ||A||_2^2.
    B = A.T / (1.01 * la.norm(A, 2) ** 2)
    B_ = tf.Variable(B, dtype=tf.float64, name='B_0')

    x_ = tf.placeholder(tf.complex128, (N, None))
    y_ = tf.placeholder(tf.complex128, (M, None))

    # The first layer: v = y
    By_ = _real_matmul_complex(B_, y_)
    theta_ = tf.Variable(theta_init, dtype=tf.float64, name='theta_0')
    OneOverM = tf.constant(float(1) / M, dtype=tf.float64)
    NOverM = tf.constant(float(N) / M, dtype=tf.complex128)
    rvar_ = OneOverM * tf.reduce_sum(tf.square(tf.abs(y_)), 0)
    xhat_, dxdr_ = eta(By_, rvar_, theta_)
    v_ = y_

    for t in range(1, T):
        b_ = NOverM * dxdr_
        # Residual update: v = y - A*xhat + b*v
        v_ = y_ - _real_matmul_complex(A_, xhat_) + b_ * v_
        temp = tf.abs(v_)
        rvar_ = OneOverM * tf.reduce_sum(temp * temp, 0)
        # Variables are created as theta_t first, then B_t, with the same
        # names as before so the saved parameters can still be matched.
        theta_ = tf.Variable(theta_init, dtype=tf.float64,
                             name='theta_' + str(t))
        if untied:
            # Untied network: every layer gets its own B matrix.
            B_ = tf.Variable(B, dtype=tf.float64, name='B_' + str(t))
        rhat_ = xhat_ + _real_matmul_complex(B_, v_)
        xhat_, dxdr_ = eta(rhat_, rvar_, theta_)

    # Per-column normalised squared error, averaged over the columns.
    nmse_ = tf.reduce_mean(
        tf.reduce_sum(tf.square(tf.abs(xhat_ - x_)), axis=0)
        / tf.reduce_sum(tf.square(tf.abs(x_)), axis=0))

    nmse_per_snr = []
    # The context manager releases the session even if evaluation fails.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        load_trainable_vars(sess, trainedfilename)

        for v in tf.global_variables():
            print(str(v.name) + ' ' + str(sess.run(v)))

        for snr in snr_values:
            D = loadmat(array_type + 'testproblem' + MN + str(snr) + 'dB.mat')
            # One forward pass per SNR; only the NMSE is needed here.
            nmse = sess.run(nmse_, feed_dict={x_: D['x'], y_: D['y']})
            nmse_dB = 10 * np.log10(nmse)
            print(str(snr) + 'dB:' + str(nmse_dB))
            nmse_per_snr.append(nmse_dB)

    nmse_SNR = np.array(nmse_per_snr, dtype=np.float64)
    print(nmse_SNR)

    # Update only the slice that belongs to this SNR range; the rest of the
    # stored vector is written back unchanged.
    results = loadmat(saveresultsname)
    nmse_curve = results[result_key][0]
    nmse_curve[ibegin:iend] = nmse_SNR
    print(nmse_curve)
    savemat(saveresultsname, {result_key: nmse_curve})


if __name__ == '__main__':
    main()
在shrink设置为gm时没有报错,但是设置为soft的时候报错
猜测可能是soft的相关训练集出现了问题?
但是训练集都是论文文件直接给好的npz文件,我要去修改它吗?
但是根据生成的mat文件进行仿真后,得到的shrink=gm的图又跟实验结果完全不一样;
才接触tensorflow和anaconda,如果太白痴各位大佬别骂我,,,嘤嘤嘤
想问下大佬们呢这个维度错误如何解决?