是多多多多多啊 2022-03-07 17:53 采纳率: 0%
浏览 40

pytorch框架下写的一个物理增强physenNet 网络 运行出来的图片不对。急着毕业

问题遇到的现象和发生背景
问题相关代码,请勿粘贴截图
运行结果及报错内容
我的解答思路和尝试过的方法
我想要达到的结果

```python

import os
import cv2
import numpy as np
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
#from torchsummary import summary


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#判断能否使用GPU加速
dx = 6.4e-3
r = 0.532e-3
z = 50


# 角谱衍射算法
def angular_diff(U0, z, dx, r):
    [M, N] = np.shape(U0)

    # 建立频谱坐标系
    fx = (1 / (N * dx)) * (np.linspace(0, N - 1, N) - N / 2)
    fy = (1 / (M * dx)) * (np.linspace(0, N - 1, N) - N / 2)
    [fx, fy] = np.meshgrid(fx, fy)

    trans = np.exp(1j*2*np.pi*z/r*np.sqrt(1-(r*fx)**2-(r*fy)**2))
    Uf = np.fft.fftshift(np.fft.fft2(U0))*trans
    U = np.fft.ifft2(np.fft.ifftshift(Uf))

    I = U*np.conj(U)
    I = np.abs(I)
    return U, I


def torch_fftshift(real, imag):
    [M, N] = real.size()
    real = torch.roll(real, shifts=(M//2, N//2), dims=(0, 1))
    imag = torch.roll(imag, shifts=(M//2, N//2), dims=(0, 1))
    return real, imag


def torch_ifftshift(real, imag):
    [M, N] = real.size()
    real = torch.roll(real, shifts=(-M//2, -N//2), dims=(0, 1))
    imag = torch.roll(imag, shifts=(-M//2, -N//2), dims=(0, 1))
    return real, imag


def torch_diff(img, dx, r, z, size):
    img_tensor_fft = torch.fft.fft(img)#, 2, True
    [M, N] = size

    fx = (1 / (N * dx)) * (np.linspace(0, N - 1, N) - N / 2)
    fy = (1 / (M * dx)) * (np.linspace(0, N - 1, N) - N / 2)
    [fx, fy] = np.meshgrid(fx, fy)
    trans = np.exp(1j * 2 * np.pi * z / r * np.sqrt(1 - (r * fx) ** 2 - (r * fy) ** 2))
    trans_real = trans.real
    trans_real = trans_real.reshape((M, N, 1))
    trans_imag = trans.imag
    trans_imag = trans_imag.reshape((M, N, 1))
    trans_ = np.append(trans_real, trans_imag, axis=2)
    trans_tensor = torch.from_numpy(trans_)
    trans_tensor = trans_tensor.float()
    trans_tensor = trans_tensor.to(device)

    trans_tensor_real = trans_tensor[:, :, 0]
    trans_tensor_imag = trans_tensor[:, :, 1]

    img_tensor_fft_real = img_tensor_fft[:, :, 0]
    img_tensor_fft_imag = img_tensor_fft[:, :, 1]
    img_tensor_fft_real, img_tensor_fft_imag = torch_fftshift(img_tensor_fft_real, img_tensor_fft_imag)

    Uf_real = torch.mul(trans_tensor_real, img_tensor_fft_real) - torch.mul(trans_tensor_imag, img_tensor_fft_imag)
    Uf_imag = torch.mul(trans_tensor_real, img_tensor_fft_imag) + torch.mul(trans_tensor_imag, img_tensor_fft_real)
    Uf = torch.stack((Uf_real, Uf_imag), 2)

    U = torch.fft.ifft(Uf)#, 2, True
    U_real = U[:, :, 0]
    # U_imag = U[:, :, 1]
    # U_real, U_imag = torch_ifftshift(U_real, U_imag)
    # U = torch.stack((U_real, U_imag), 2)
    intensity = torch.mul(U_real, U_real)

    return U, intensity


def read_data(img_path, mode="RGB"):
    imgs_name = os.listdir(img_path)
    imgs_dir = [os.path.join(img_path, name) for name in imgs_name]
    imgs = []
    for path in imgs_dir:
        if mode == "L":
            img = cv2.imread(path, 0)
            # img = cv2.resize(img, (512, 512))
            # img = img.reshape(512, 512, 1)
        else:
            img = cv2.imread(path)
            # img = cv2.resize(img, (512, 512))
            # img = img.reshape(512, 512, 3)

        imgs.append(img)
    return imgs


class My_Dataset(Dataset):
    def __init__(self, img_path, transform=None):
        self.imgs = read_data(img_path, mode="L")
        # self.labels = read_data(label_path, mode="L")
        self.tranform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        # label = self.labels[idx]
        if self.tranform:
            img = self.tranform(img)
            # label = self.tranform(label)
        return img


trans = transforms.Compose([
    transforms.ToTensor(),
])

train_set = My_Dataset('data/',
                       transform=trans)
# val_set = My_Dataset('data/test/1',
#                      "data/test/2", transform=trans)

batch_size = 1
data_loader = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
}


# train_img = cv2.imread('UASF_50.bmp', 0)
# train_img = trans(train_img)


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU(inplace=True)
    )


class UNet_model(nn.Module):
    def __init__(self):
        super().__init__()

        self.dconv_down1 = double_conv(1, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        x = self.dconv_down4(x)

        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)

        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)

        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)

        x = self.dconv_up1(x)
        out = self.conv_last(x)

        result = out[0][0]

        # focus_ = out.data.cpu().numpy()
        # focus_ = focus_.reshape((512, 512))
        # [_, focus] = angular_diff(focus_, z, dx, r)
        # focus = 255 * focus / np.max(focus)
        # focus = focus.reshape((1, 1, 512, 512))
        # focus = torch.from_numpy(focus)

        img_real = torch.reshape(out, (512, 512))
        img_imag = torch.zeros((512, 512), device=device)
        img = torch.stack((img_real, img_imag), 2)
        U, intensity = torch_diff(img, dx, r, z, (512, 512))
        intensity = torch.reshape(intensity, (1, 1, 512, 512))

        return intensity, result


model = UNet_model().to(device)
# summary(model, input_size=(1, 512, 512))

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# train
for epoch in range(10000):
    for images in data_loader['train']:

        images = images.to(device)
        # labels = labels.to(device)
        output, result = model(images)
        #output = output.to(device)

        output=torch.tensor(output,dtype=float)
        #output=output.clone().detach()
        loss = criterion(images, output)

        optimizer.zero_grad()
        loss.requires_grad_(True)
        loss.backward()
        optimizer.step()
        print('Epoch[{}], loss: {:.4f}'.format(epoch, loss.item()))

        if epoch % 500 == 0:
            torchvision.utils.save_image(result, str(epoch) + ".png")

model.eval()
torch.save(model.state_dict(), 'Diff.ckpt')


# model.load_state_dict(torch.load('Diff.ckpt'))
# model.eval()
#
# # test
# for images in data_loader['train']:
#     images = images.to(device)
#     output = model(images)
#
#     img = output[0][0]
#     torchvision.utils.save_image(img, "result.png")

#     img = output[1][0]
#     torchvision.utils.save_image(img, str(i * 2 + 1) + ".png")


```

  • 写回答

1条回答 默认 最新

  • 丨封尘绝念斩丨 2022-03-07 18:12
    关注

    报的错误是啥

    评论

报告相同问题?

问题事件

  • 创建了问题 3月7日

悬赏问题

  • ¥15 如何让企业微信机器人实现消息汇总整合
  • ¥50 关于#ui#的问题:做yolov8的ui界面出现的问题
  • ¥15 如何用Python爬取各高校教师公开的教育和工作经历
  • ¥15 TLE9879QXA40 电机驱动
  • ¥20 对于工程问题的非线性数学模型进行线性化
  • ¥15 Mirare PLUS 进行密钥认证?(详解)
  • ¥15 物体双站RCS和其组成阵列后的双站RCS关系验证
  • ¥20 想用ollama做一个自己的AI数据库
  • ¥15 关于qualoth编辑及缝合服装领子的问题解决方案探寻
  • ¥15 请问怎么才能复现这样的图呀