zxcKAMEHAME 2023-05-24 03:29 采纳率: 25%
浏览 13

Unet训练,影像传入路径来自多个路径 修改

我在训练unet时报错,原因是mygene传入的参数是folders列表,但是我的影像是来自六个不同的资料夹(五张单通道影像 一张三通道影像),该怎么修改
unet model部分已修改成八通道 def unet(pretrained_weights = None,input_size = (256,256,8)):
参考的是zhixuhao/unet

main.py

from model import *
from data import *

#os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def show_train_history(train_history, train,name):
    plt.plot(train_history.history[train])
    plt.title(name)
    plt.ylabel('train')
    plt.xlabel('Epoch')
    plt.legend(['train'], loc='center right')
    #plt.show()
    plt.savefig(name+".png")
    plt.close()


acc_name ="Accuracy"
loss_name="Loss"


data_gen_args = dict(rotation_range=0.2,
                    width_shift_range=0.05,
                    height_shift_range=0.05,
                    shear_range=0.05,
                    zoom_range=0.05,
                    horizontal_flip=True,
                    fill_mode='nearest')
# 訓練資料路徑列表
folders = [
    'data/syu/train1day/auto_correlation1_reverse',
    'data/syu/train1day/energy1',
    'data/syu/train1day/entropy1_reverse',
    'data/syu/train1day/homogeneity1',
    'data/syu/train1day/temprature',
    'data/syu/train1day/clahe1'
]
myGene = trainGenerator(8, folders,'image','label',data_gen_args,save_to_dir = None)

model = unet(input_size=(256, 256, 8))
model_checkpoint = ModelCheckpoint('weight/unet_8channels.hdf5', monitor='loss',verbose=1, save_best_only=True)
train_history=model.fit_generator(myGene,steps_per_epoch=200,epochs=20,callbacks=[model_checkpoint])

acc_name ="Accuracy"
loss_name="Loss"
show_train_history(train_history, '1accuracy',acc_name)
show_train_history(train_history, 'loss',loss_name)

data.py 报错的部分在def trainGenerator,

from __future__ import print_function
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np 
import os
import glob
import skimage.io as io
import skimage.transform as trans
import cv2
from matplotlib import pyplot as plt
from skimage import img_as_ubyte
import tensorflow as tf

Sky = [128,128,128]
Building = [128,0,0]
Pole = [192,192,128]
Road = [128,64,128]
Pavement = [60,40,222]
Tree = [128,128,0]
SignSymbol = [192,128,128]
Fence = [64,64,128]
Car = [64,0,128]
Pedestrian = [64,64,0]
Bicyclist = [0,128,192]
Unlabelled = [0,0,0]

COLOR_DICT = np.array([Sky, Building, Pole, Road, Pavement,
                          Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled])



def normalized_the_gray_value(img):
    gray_max=img.max()
    gray_min=img.min()
    
    normalized_img= ( (img-gray_min) / (gray_max-gray_min) )*255
    
    return normalized_img
    

#  change the img to  YUV and normalized the Y value
def RGB_Y_normalized(img,gamma_value):
    #print(img_out)
    #print(gamma_value)
    YUV = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
    
    Y, U, V = cv2.split(YUV)
    #print(y)
    y_max=Y.max()
    y_min=Y.min()
    #print(y_min)
    #print(y_max)
    temp_y=  ( ( ( Y-y_min)  / ( y_max - y_min ) )  **gamma_value )*255 
    
    # Y normalized and return
    #YCrCb_array = np.zeros((img.shape[0], img.shape[1], 3), "uint8")
    #YUV_array = np.zeros((img.shape[0], img.shape[1], 3))
    YUV_array = np.zeros((img.shape[0], img.shape[1], 3), "uint8")
    YUV_array[:, :, 0], YUV_array[:, :, 1], YUV_array[:, :, 2] = temp_y, YUV[:,:,1], YUV[:,:,2]
    
    #print(YCrCb_array[:, :, 0])
    
    
    # return the YUV channel  to RGB channel image array
    #print(np.shape(YUV_array))
    final_rgb=cv2.cvtColor(YUV_array, cv2.COLOR_YUV2BGR)
    
    
    
    
    return final_rgb
    
    

#  change the img to  YCbCr and normalized the Y value
def YCbCr_normalized(img,gamma_value):
    #print(img_out)
    #print(gamma_value)
    YCrCb = cv2.cvtColor(img, cv2.COLOR_BGR2YCR_CB)
    
    Y, Cr, Cb = cv2.split(YCrCb)
    #print(y)
    y_max=Y.max()
    y_min=Y.min()
    #print(y_min)
    #print(y_max)
    temp_y=  ( ( ( Y-y_min)  / ( y_max - y_min ) )  **gamma_value )*255 
    
    # Y normalized and return
    #YCrCb_array = np.zeros((img.shape[0], img.shape[1], 3), "uint8")
    YCrCb_array = np.zeros((img.shape[0], img.shape[1], 3))
    YCrCb_array[:, :, 0], YCrCb_array[:, :, 1], YCrCb_array[:, :, 2] = temp_y, YCrCb[:,:,1], YCrCb[:,:,2]
    
    #print(YCrCb_array[:, :, 0])
    return YCrCb_array


def adjustData(img,mask,flag_multi_class,num_class):
    if(flag_multi_class):
        img = img / 255
        mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0]
        new_mask = np.zeros(mask.shape + (num_class,))
        for i in range(num_class):
            #for one pixel in the image, find the class in mask and convert it into one-hot vector
            #index = np.where(mask == i)
            #index_mask = (index[0],index[1],index[2],np.zeros(len(index[0]),dtype = np.int64) + i) if (len(mask.shape) == 4) else (index[0],index[1],np.zeros(len(index[0]),dtype = np.int64) + i)
            #new_mask[index_mask] = 1
            new_mask[mask == i,i] = 1
        new_mask = np.reshape(new_mask,(new_mask.shape[0],new_mask.shape[1]*new_mask.shape[2],new_mask.shape[3])) if flag_multi_class else np.reshape(new_mask,(new_mask.shape[0]*new_mask.shape[1],new_mask.shape[2]))
        mask = new_mask
    elif(np.max(img) > 1):
        img = img / 255
        mask = mask /255
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0
    return (img,mask)



def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "rgba",
                    mask_color_mode = "grayscale",image_save_prefix  = "image",mask_save_prefix  = "mask",
                    flag_multi_class = False,num_class = 2,save_to_dir = "data_train/test/",target_size = (256,256),seed = 1):
    '''
    can generate image and mask at the same time
    use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same
    if you want to visualize the results of generator, set save_to_dir = "your path"
    '''
    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)
    image_generator = image_datagen.flow_from_directory(
        train_path,
        classes = [image_folder],
        class_mode = None,
        color_mode = image_color_mode,
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_to_dir,
        save_prefix  = image_save_prefix,
        seed = seed)
    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes = [mask_folder],
        class_mode = None,
        color_mode = mask_color_mode,
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = save_to_dir,
        save_prefix  = mask_save_prefix,
        seed = seed)
     
    train_generator = zip(image_generator, mask_generator)
    for (img,mask) in train_generator:
        img,mask = adjustData(img,mask,flag_multi_class,num_class)
        yield (img,mask)
    save_to_dir = "C:/Labbb" #橘色字部分 可視化結果



def testGenerator(test_path,num_image = 30,target_size = (256,256),flag_multi_class = False,as_gray = True):
    for i in range(num_image):
        img = io.imread(os.path.join(test_path,"%d.png"%i),as_gray = False)
        img = img / 255
        img = trans.resize(img,target_size)
        #img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img
        img = np.reshape(img,(1,)+img.shape)
        yield img


def geneTrainNpy(image_path,mask_path,gamma_value,flag_multi_class = False,num_class = 2,image_prefix = "image",mask_prefix = "mask",image_as_gray = True,mask_as_gray = True):
    
    # image_name_arr = glob.glob(os.path.join(image_path,"%s*.png"%image_prefix))
    image_name_arr = glob.glob(os.path.join(image_path,"*.png"))

    # mask name array ----  get the mask path
    mask_name_arr = glob.glob(os.path.join(mask_path,"*.png"))
    
    
    target_size = (256,256)
    image_arr = []
    mask_arr = []
    for index,item in enumerate(image_name_arr):
        #img = io.imread(item,as_gray = image_as_gray)
        img = cv2.imdecode(np.fromfile(item, dtype=np.uint8), 1)

        #plt.subplot(1, 2, 1)
        #plt.title('img')
        #plt.imshow(img)
        
        # Rerurn the Y normalized rgb channel array
        img =RGB_Y_normalized(img,gamma_value)
        
        #print("take RGB to YCbCr " ,end=" ")
        # resize the image
        img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_LINEAR)
        #img = trans.resize(img,target_size)
        
        
        
        #img = normalized_the_gray_value(img)
        
        #plt.subplot(1, 2, 2)
        #plt.title('img')
        #plt.imshow(img)
        
        #plt.imshow(img,cmap='gray')
        
        
        img = np.reshape(img,img.shape + (1,)) if image_as_gray else img
        #mask = io.imread(item.replace(image_path,mask_path).replace(image_prefix,mask_prefix),as_gray = mask_as_gray)
        
        mask = cv2.imdecode(np.fromfile(mask_name_arr[index], dtype=np.uint8), 0)
        #mask = io.imread(item,as_gray = mask_as_gray)
        
        
        
        
        
        # resize the mask
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_LINEAR)
        #mask = trans.resize(mask,target_size)
        
        #plt.subplot(1, 2, 2)
        #plt.title('mask')
        #plt.imshow(mask,cmap='gray')
        
        mask = np.reshape(mask,mask.shape + (1,)) if mask_as_gray else mask
        
       
        #plt.show()
        
        #print(index ,item , mask_name_arr[index],img.shape)
        
        img,mask = adjustData(img,mask,flag_multi_class,num_class)
        image_arr.append(img)
        mask_arr.append(mask)
    image_arr = np.array(image_arr)
    mask_arr = np.array(mask_arr)
    return image_arr,mask_arr


'''def labelVisualize(num_class,color_dict,img):
    img = img[:,:,0] if len(img.shape) == 3 else img
    img_out = np.zeros(img.shape + (3,))
    for i in range(num_class):
        img_out[img == i,:] = color_dict[i]
    return img_out / 255'''



'''def saveResult(save_path,npyfile,flag_multi_class = False,num_class = 2):
    for i,item in enumerate(npyfile):
        img = labelVisualize(num_class,COLOR_DICT,item) if flag_multi_class else item[:,:,0]
        io.imsave(os.path.join(save_path,"%d_predict.jpg"%i),img)'''

def predict_RGBcolor_img(test_path,save_path,model,test_pic_num):
    target_size = (256,256)
    files_name=os.listdir(test_path)
    for i in range(len(files_name)):
        img = io.imread(os.path.join(test_path,"%s"%files_name[i]))
        #print(files_name)
       # io.imsave(os.path.join(save_path,"%s"%files_name[i]),img_as_ubyte(img))
        img = img / 255
        img = trans.resize(img,target_size)
        img = np.reshape(img,(1,)+img.shape)
        results = model.predict(img)
        a=results[0][:,:,0]
        '''for i in range(len(a)):
            for j in range(len(a[i])):
                if a[i][j]>0.5:
                    a[i][j]=255
                else:
                    a[i][j]=0'''
        a[a<=0.5]=0
        a[a>0.5]=255
        a=a.astype(np.uint8)
        #print(a.shape)
        #a[a<0.5]=0
        #a[a>=0.5]=255
        #plt.imshow(results[0][:,:,0],cmap="gray")
        #plt.show()
        io.imsave(os.path.join(save_path,"%s"%files_name[i]),img_as_ubyte(a))

  • 写回答

1条回答 默认 最新

      报告相同问题?

      相关推荐 更多相似问题

      问题事件

      • 创建了问题 5月24日

      悬赏问题

      • ¥100 支付宝sdk原生链接转h5
      • ¥20 VS2019如何添加.mdf文件失败
      • ¥15 SeaTunnel多Transform配置问题
      • ¥15 消除字符串,求最短字符串长度
      • ¥20 有人做基于集员滤波的异常值处理相关的内容吗?(语言-matlab)
      • ¥30 matlab编程,用chatGPT帮助,但给出的code总是报错。
      • ¥15 离线安装VS2017出现报错
      • ¥50 opengl2怎么将梯形的纹理映射在矩形上面不变形
      • ¥15 起终点不同的tsp旅行商问题
      • ¥15 博途V16变频器CU320-2pn版本为2.34的gsd文件