lukas_dsc 2022-07-11 00:07
Closed

Keras fit progress output is not continuous after loading and cleaning data with pandas

Problem description and background

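While building a simple model in Keras, I read and cleaned the data with pandas, converted it to NumPy arrays with np.array(), and fed it into the network (just a few fully connected layers with a sigmoid output). When I call fit with batch_size=40 and epochs=10, the progress output is not continuous: the count jumps from 40 straight to 1000, then to 2000. Shouldn't it go 40, 80, 120, and so on? See the screenshot below: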

[Screenshot of the training log; image not available]
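The expected sequence follows from simple arithmetic: the 3200 training rows (train_data = data[0:3200, :]) split into batches of 40 give 80 batches per epoch, so a counter updated after every batch would read 40, 80, 120, up to 3200. A small stdlib sketch of that arithmetic (`batch_schedule` is a hypothetical helper, not part of Keras):

```python
def batch_schedule(n_samples, batch_size):
    """Cumulative number of samples processed after each mini-batch.

    The last batch is smaller when n_samples is not a multiple of batch_size.
    """
    counts = []
    seen = 0
    while seen < n_samples:
        seen = min(seen + batch_size, n_samples)  # one more batch processed
        counts.append(seen)
    return counts


schedule = batch_schedule(3200, 40)  # 3200 training rows, batch_size=40
print(len(schedule))       # 80 batches per epoch
print(schedule[:3])        # [40, 80, 120]
print(schedule[-1])        # 3200
print(10 * len(schedule))  # 800 weight updates over epochs=10
```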

Relevant code

import numpy as np
import pandas as pd
import keras as k

# Raw string keeps the backslashes in the Windows path literal
original_data = pd.read_csv(r"E:\Desktop\stroke_data.csv")
def data_process(original_data):
    # Drop rows with any missing value (e.g. a missing bmi) and the id column
    data = original_data.dropna(axis=0, how='any')

    data = data.drop(columns='id')

    data = data.reset_index(drop=True)
    # gender: 'Male' -> 1, anything else -> 0
    gender = np.zeros(data.shape[0])
    for i in range(gender.shape[0]):
        if data.loc[i, 'gender'] == 'Male':
            gender[i] = 1
        else:
            gender[i] = 0
    data = data.drop(columns='gender')
    data.insert(0, 'gender', gender)

    # ever_married: 'Yes' -> 1, anything else -> 0
    ever_married = np.zeros(data.shape[0])
    for i in range(ever_married.shape[0]):
        if data['ever_married'][i] == 'Yes':
            ever_married[i] = 1
        else:
            ever_married[i] = 0
    data = data.drop(columns='ever_married')
    data.insert(4, 'ever_married', ever_married)

    # Residence_type: 'Urban' -> 1, anything else -> 0
    Residence_type = np.zeros(data.shape[0])
    for i in range(Residence_type.shape[0]):
        if data['Residence_type'][i] == 'Urban':
            Residence_type[i] = 1
        else:
            Residence_type[i] = 0
    data = data.drop(columns='Residence_type')
    data.insert(6, 'Residence_type', Residence_type)

    # Scale the 0/1 flags to 0/10
    data['heart_disease'] = data['heart_disease'].apply(lambda x: x * 10)
    data['hypertension'] = data['hypertension'].apply(lambda x: x * 10)

    # work_type: ordinal code from 5 down to 1
    work_type = np.zeros(data.shape[0])
    for i in range(work_type.shape[0]):
        if data['work_type'][i] == 'Self-employed':
            work_type[i] = 5
        elif data['work_type'][i] == 'Private':
            work_type[i] = 4
        elif data['work_type'][i] == 'children':  # assumed spelling (Kaggle stroke data); 'Self-children' never matched
            work_type[i] = 3
        elif data['work_type'][i] == 'Govt_job':
            work_type[i] = 2
        else:
            work_type[i] = 1
    data = data.drop(columns='work_type')
    data.insert(5, 'work_type', work_type)

    # Drop rows whose smoking_status is 'Unknown'
    rows = [x for x in data.index if data.loc[x]['smoking_status'] == 'Unknown']
    data = data.drop(rows, axis=0)

    data = data.reset_index(drop=True)
    # smoking_status: 'never smoked' -> 0, 'formerly smoked' -> 10, otherwise 20
    smoking_status = np.zeros(data.shape[0])
    for i in range(smoking_status.shape[0]):
        if data['smoking_status'][i] == 'never smoked':
            smoking_status[i] = 0
        elif data['smoking_status'][i] == 'formerly smoked':
            smoking_status[i] = 10
        else:
            smoking_status[i] = 20
    data = data.drop(columns='smoking_status')
    data.insert(9, 'smoking_status', smoking_status)

    # Min-max scale the continuous columns into [0, 10]
    bmi_max = data.bmi.max()
    bmi_min = data.bmi.min()
    data['bmi'] = data['bmi'].apply(lambda x: (x - bmi_min) / ((bmi_max - bmi_min) / 10))

    age_max = data.age.max()
    age_min = data.age.min()
    data['age'] = data['age'].apply(lambda x: (x - age_min) / ((age_max - age_min) / 10))

    glucose_max = data.avg_glucose_level.max()
    glucose_min = data.avg_glucose_level.min()
    data['avg_glucose_level'] = data['avg_glucose_level'].apply(
        lambda x: (x - glucose_min) / ((glucose_max - glucose_min) / 10))

    return data

processed_data = data_process(original_data)
data = np.array(processed_data.iloc[:, :-1])
labels = np.expand_dims(np.array(processed_data.iloc[:, -1]), axis=-1)

train_data = data[0:3200, :]
train_label = labels[0:3200, :]

test_data = data[3200:, :]
test_label = labels[3200:, :]

model = k.Sequential()
model.add(k.layers.Dense(8, activation='relu', input_shape=(10,)))
model.add(k.layers.Dense(6, activation='relu'))
model.add(k.layers.Dense(4, activation='relu'))
model.add(k.layers.Dense(1, activation='sigmoid'))

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['acc'])

model.fit(train_data, train_label, epochs=10, batch_size=40, verbose=1, validation_data=(test_data, test_label))
model.evaluate(test_data, test_label)
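A plausible cause, stated as an assumption because the screenshot is not available: with verbose=1, Keras 2.x reports progress through its `Progbar` utility, which takes an `interval` argument (0.05 s by default) and skips a redraw when less time than that has passed since the previous one. In a console that cannot overwrite the current line (some IDE run windows, for example), each redraw appears as a new line, so a tiny network that finishes a batch in a couple of milliseconds only shows every 25th batch or so. The sketch below simulates such a time-throttled display with made-up timings; `displayed_counts`, `batch_ms` and `interval_ms` are hypothetical names, not Keras API.

```python
def displayed_counts(n_samples, batch_size, batch_ms, interval_ms):
    """Return the sample counts a time-throttled progress bar would print.

    Hypothetical simulation: every batch is assumed to take `batch_ms`
    milliseconds, and the display is redrawn only if at least `interval_ms`
    milliseconds have passed since the last redraw (the first and the final
    update are always shown). All batches are still trained; only the
    printing is skipped.
    """
    shown = []
    last_redraw = None  # simulated time of the previous redraw
    clock = 0           # simulated elapsed time in milliseconds
    seen = 0            # samples processed so far
    while seen < n_samples:
        seen = min(seen + batch_size, n_samples)  # one batch is trained
        clock += batch_ms
        finished = seen == n_samples
        if last_redraw is None or clock - last_redraw >= interval_ms or finished:
            shown.append(seen)
            last_redraw = clock
    return shown


# 3200 training rows, batch_size=40, ~2 ms per batch, 50 ms redraw interval
print(displayed_counts(3200, 40, batch_ms=2, interval_ms=50))
# [40, 1040, 2040, 3040, 3200]
```

All 80 batches are trained, but only five lines are printed, roughly 1000 samples apart. If this is what is happening, the data and the model are fine; passing verbose=2 to fit prints one line per epoch instead of a progress bar.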

What I want to achieve

I want to know what causes this and how to fix it.


1 answer

  • lukas_dsc 2022-07-11 00:09


Question events

  • Closed by the system on July 19
  • Question created on July 11
