Problem encountered and background
Code related to the problem (please do not paste screenshots)
Run results and error messages
My approach and what I have tried
The result I want to achieve
```python
import os
import cv2
import numpy as np
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
#from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU when one is available
dx = 6.4e-3
r = 0.532e-3
z = 50
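# Assumed units (not stated in the original post): dx is the camera pixel
# pitch, r the wavelength (0.532e-3 -> 532 nm) and z the propagation
# distance, all in millimetres.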
# Angular spectrum diffraction (NumPy reference implementation)
def angular_diff(U0, z, dx, r):
    [M, N] = np.shape(U0)
    # build the spatial-frequency grid
    fx = (1 / (N * dx)) * (np.linspace(0, N - 1, N) - N / 2)
    fy = (1 / (M * dx)) * (np.linspace(0, M - 1, M) - M / 2)
    [fx, fy] = np.meshgrid(fx, fy)
    # transfer function of free-space propagation over distance z
    trans = np.exp(1j * 2 * np.pi * z / r * np.sqrt(1 - (r * fx) ** 2 - (r * fy) ** 2))
    Uf = np.fft.fftshift(np.fft.fft2(U0)) * trans
    U = np.fft.ifft2(np.fft.ifftshift(Uf))
    I = np.abs(U * np.conj(U))
    return U, I
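# Example usage (hypothetical input): a uniform 512x512 plane wave
#   U0 = np.ones((512, 512), dtype=complex)
#   U, I = angular_diff(U0, z, dx, r)   # field and intensity after distance z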
# Legacy fftshift helpers for the old real/imag-pair layout; with torch >= 1.8
# torch.fft.fftshift / torch.fft.ifftshift can be used directly instead.
def torch_fftshift(real, imag):
    [M, N] = real.size()
    real = torch.roll(real, shifts=(M // 2, N // 2), dims=(0, 1))
    imag = torch.roll(imag, shifts=(M // 2, N // 2), dims=(0, 1))
    return real, imag

def torch_ifftshift(real, imag):
    [M, N] = real.size()
    real = torch.roll(real, shifts=(-M // 2, -N // 2), dims=(0, 1))
    imag = torch.roll(imag, shifts=(-M // 2, -N // 2), dims=(0, 1))
    return real, imag
# Angular spectrum diffraction in PyTorch, kept differentiable so it can sit
# inside the training graph. torch.fft.fft2 (torch >= 1.8) operates on complex
# tensors directly, so the old real/imag channel bookkeeping is unnecessary.
def torch_diff(img, dx, r, z, size):
    [M, N] = size
    fx = (1 / (N * dx)) * (np.linspace(0, N - 1, N) - N / 2)
    fy = (1 / (M * dx)) * (np.linspace(0, M - 1, M) - M / 2)
    [fx, fy] = np.meshgrid(fx, fy)
    trans = np.exp(1j * 2 * np.pi * z / r * np.sqrt(1 - (r * fx) ** 2 - (r * fy) ** 2))
    trans_tensor = torch.from_numpy(trans).to(torch.complex64).to(device)
    # shift the spectrum, apply the transfer function, shift back and invert
    Uf = torch.fft.fftshift(torch.fft.fft2(img)) * trans_tensor
    U = torch.fft.ifft2(torch.fft.ifftshift(Uf))
    intensity = torch.abs(U) ** 2
    return U, intensity
# Read every image under img_path; mode "L" loads grayscale, anything else BGR
def read_data(img_path, mode="RGB"):
    imgs_name = os.listdir(img_path)
    imgs_dir = [os.path.join(img_path, name) for name in imgs_name]
    imgs = []
    for path in imgs_dir:
        if mode == "L":
            img = cv2.imread(path, 0)
            # img = cv2.resize(img, (512, 512))
            # img = img.reshape(512, 512, 1)
        else:
            img = cv2.imread(path)
            # img = cv2.resize(img, (512, 512))
            # img = img.reshape(512, 512, 3)
        imgs.append(img)
    return imgs
class My_Dataset(Dataset):
    def __init__(self, img_path, transform=None):
        self.imgs = read_data(img_path, mode="L")
        # self.labels = read_data(label_path, mode="L")
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        # label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
            # label = self.transform(label)
        return img
trans = transforms.Compose([
    transforms.ToTensor(),
])
train_set = My_Dataset('data/', transform=trans)
# val_set = My_Dataset('data/test/1',
#                      "data/test/2", transform=trans)
batch_size = 1
data_loader = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
}
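# Each batch yielded by data_loader['train'] is a (1, 1, H, W) float tensor in
# [0, 1]: ToTensor adds the channel dimension to the grayscale image and the
# DataLoader adds the batch dimension.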
# train_img = cv2.imread('UASF_50.bmp', 0)
# train_img = trans(train_img)
# Two 3x3 convolutions, each followed by ReLU (the standard UNet block)
def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU(inplace=True)
    )
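# Example: double_conv(1, 64) maps a (B, 1, H, W) tensor to (B, 64, H, W);
# the padding=1 3x3 kernels preserve the spatial size.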
class UNet_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.dconv_down1 = double_conv(1, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        self.conv_last = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)
        x = self.dconv_down4(x)
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        result = out[0][0]
        # focus_ = out.data.cpu().numpy()
        # focus_ = focus_.reshape((512, 512))
        # [_, focus] = angular_diff(focus_, z, dx, r)
        # focus = 255 * focus / np.max(focus)
        # focus = focus.reshape((1, 1, 512, 512))
        # focus = torch.from_numpy(focus)
        # torch.fft.fft2 accepts a real tensor, so the zero imaginary channel
        # that the old stacked layout needed is no longer required
        img_real = torch.reshape(out, (512, 512))
        U, intensity = torch_diff(img_real, dx, r, z, (512, 512))
        intensity = torch.reshape(intensity, (1, 1, 512, 512))
        return intensity, result
model = UNet_model().to(device)
# summary(model, input_size=(1, 512, 512))
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# train
for epoch in range(10000):
    for images in data_loader['train']:
        images = images.to(device)
        # labels = labels.to(device)
        output, result = model(images)
        # note: re-wrapping the output with torch.tensor(output) would detach
        # it from the autograd graph and leave the loss without gradients
        # (which is what the old loss.requires_grad_(True) hack papered over),
        # so the model output is used directly
        loss = criterion(output, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Epoch [{}], loss: {:.4f}'.format(epoch, loss.item()))
    if epoch % 500 == 0:
        # detach before saving: save_image cannot handle a tensor that still
        # requires grad
        torchvision.utils.save_image(result.detach(), str(epoch) + ".png")
model.eval()
torch.save(model.state_dict(), 'Diff.ckpt')
# model.load_state_dict(torch.load('Diff.ckpt'))
# model.eval()
#
# # test
# for images in data_loader['train']:
#     images = images.to(device)
#     output = model(images)
#
#     img = output[0][0]
#     torchvision.utils.save_image(img, "result.png")
#     img = output[1][0]
#     torchvision.utils.save_image(img, str(i * 2 + 1) + ".png")
```
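A minimal sanity check (my own addition, assuming the corrected `torch_diff` above): propagate the same random field with both the NumPy and the PyTorch implementations and compare the resulting intensities. The 256×256 size and the random test field are arbitrary choices for illustration, and the script above must have been run first so that `angular_diff`, `torch_diff`, `dx`, `r`, `z` and `device` are defined.

```python
# assumes the script above has already been executed
M = N = 256
U0 = np.random.rand(M, N).astype(np.float32)  # hypothetical test field

# NumPy reference propagation
_, I_np = angular_diff(U0, z, dx, r)

# PyTorch propagation of the same field
U0_t = torch.from_numpy(U0).to(device)
_, I_t = torch_diff(U0_t, dx, r, z, (M, N))

# the two intensities should agree up to float32 rounding error
print(np.abs(I_np - I_t.cpu().numpy()).max())
```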