SilyaSophie 2024-04-01 22:15 采纳率: 46.2%
浏览 7

关于#Python#生成对抗网络代码的问题,如何解决?(相关搜索:python代码|数据集|训练集)

代码如下:

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Conv1D,GRU,Dropout,InputLayer
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np
import pandas as pd
import import2023


class GAN:
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='mse', optimizer=optimizer)

    def build_generator(self):

        model = Sequential()
        model.add(GRU(256, input_shape=self.img_shape, activation='relu'))
        #model.add(Dense(256, input_dim=self.latent_dim))
        model.add(Dense(512, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(Dense(512, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(Dense(1024, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,1,1))
        img = model(noise)

        return Model(noise,img)

    def build_discriminator(self):

        model = Sequential()

        #model.add(Flatten(input_shape=self.img_shape))
        model.add(InputLayer(input_shape=(32, 64)))
        #model.add(Dense(256,input_shape=self.img_shape,activation='relu'))
        #Conv2D(filters, kernel_size, data_format='NHWC')
        model.add(Conv1D(1024, kernel_size=3, strides=2, padding='same', data_format='channels_first',activation='relu'))
        model.add(Dense(512))
        model.add(Dense(512, activation='relu'))
        model.add(Dense(256))
        model.add(Dense(64, activation='relu'))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img,validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        # csv文件路径
        csv_path_train = 'E:/dataset/CICIoT2023/benign.csv'
        # 读取数据
        X_train= pd.read_csv(csv_path_train)
        # Rescale -1 to 1
        X_train = X_train / 127.5 - 1.
        #X_train = np.expand_dims(X_train, axis=3)
        X_train = np.reshape(X_train, (-1, 100, 46))
        print(X_train.shape)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        print(valid.shape)
        print(fake.shape)


        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            print(imgs.shape)
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim,1))
            print(noise.shape)
            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim,1))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # # If at save interval => save generated image samples
            # if epoch % sample_interval == 0:
            #     self.sample_images(epoch)

if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, sample_interval=200)

运行结果及报错:

img

img

img

img

img

img

  • 写回答

3条回答 默认 最新

  • 码农阿豪@新空间代码工作室 Java领域优质创作者 2024-04-01 22:30
    关注
    让阿豪来帮你解答,本回答参考chatgpt3.5编写提供,如果还有疑问可以评论或留言
    该代码实现的是一个基于深度学习的生成对抗网络(GAN)模型,用于生成28x28像素的手写数字图像。
    1. 问题:这段代码所实现的功能是什么? 回答: 该段代码实现的是一个基于深度学习的生成对抗网络(GAN)模型,用于生成28x28像素的手写数字图像。
    2. 问题:这段代码中所使用的算法或技术是什么? 回答: 该段代码中使用了生成对抗网络(GAN)算法,包括生成器和判别器两个部分。
    3. 问题:这段代码中的每个函数或方法都是用来做什么的? 回答:
    4. __init__(self):初始化生成对抗网络模型。
    5. build_generator(self):构建生成器模型,生成28x28的图像。
    6. build_discriminator(self):构建判别器模型,判别输入的图像是真实的还是虚假的。
    7. GAN类中的其他函数和方法包括训练模型所需用到的定义和编译模型、导入所需的库、以及实现模型训练、评估等功能的代码。
    8. 问题:这段代码在实现功能时存在哪些缺陷或不足?如何改进? 回答: 该段代码在实现功能时没有明显的缺陷或不足。但是,它只实现了生成28x28像素的手写数字图像,不能应用于其他类型的图像生成。可以改进的方向包括增加模型的泛化能力,使其能够生成其他类型的图像,以及提高模型的训练速度和准确性。可以探索使用更先进的神经网络和优化算法,如增加卷积层、使用批量归一化、使用更优秀的优化算法等。
    评论

报告相同问题?

问题事件

  • 创建了问题 4月1日

悬赏问题

  • ¥15 部分客户订单定位有误的问题
  • ¥15 如何在maya程序中利用python编写领子和褶裥的模型的方法
  • ¥15 Linux权限管理相关操作(求解答)
  • ¥15 Bug traq 数据包 大概什么价
  • ¥15 在anaconda上pytorch和paddle paddle下载报错
  • ¥25 自动填写QQ腾讯文档收集表
  • ¥15 DbVisualizer Pro 12.0.7 sql commander光标错位 显示位置与实际不符
  • ¥15 android 打包报错
  • ¥15 关于stm32的问题
  • ¥15 ncode振动疲劳分析中,noisefloor如何影响PSD函数?