在 Jupyter 中训练 MNIST 模型后,模型为什么保存不了?
parse_mnist_data.ipynb:
import numpy as np
import struct
from PIL import Image
import os
def make_img(fn_img = "t10k-images.idx3-ubyte",
             fn_label = "t10k-labels.idx1-ubyte",
             size1 = 7840016, size2 = 10008, root = "test"):
    """Unpack an IDX image/label file pair into per-digit folders of PNGs.

    The image file starts with four big-endian uint32 values (magic, image
    count, rows, columns) followed by one byte per pixel; the label file
    starts with two (magic, label count) followed by one byte per label.
    Image ``i`` is written to ``<root>/<label>/<i zero-padded to 5>.png``.

    paras:
        fn_img: str, path of the IDX image file,
        fn_label: str, path of the IDX label file,
        size1: int, ignored; kept only so existing callers keep working,
        size2: int, ignored; kept only so existing callers keep working,
        root: str, output folder; sub-folders "0".."9" are created in it,
    return: None;

    The payload lengths are taken from the file headers instead of
    ``size1``/``size2``, so the same call works for files of any size
    without the caller having to know their byte counts.
    """
    fmt_img_header = ">IIII"
    fmt_label_header = ">II"

    with open(fn_img, 'rb') as f:
        data_buf = f.read()
    with open(fn_label, 'rb') as f:
        label_buf = f.read()

    # The magic numbers are read only to step over them; they are not checked.
    _magic_img, num_images, num_rows, num_cols = struct.unpack_from(
        fmt_img_header, data_buf, 0)
    _magic_label, num_labels = struct.unpack_from(
        fmt_label_header, label_buf, 0)

    # Interpret the payload bytes in place rather than unpacking millions of
    # Python ints into a tuple first.
    images = np.frombuffer(
        data_buf,
        dtype=np.uint8,
        count=num_images * num_rows * num_cols,
        offset=struct.calcsize(fmt_img_header),
    ).reshape(num_images, num_rows, num_cols)
    labels = np.frombuffer(
        label_buf,
        dtype=np.uint8,
        count=num_labels,
        offset=struct.calcsize(fmt_label_header),
    ).astype(np.int64)

    # One folder per digit class; makedirs also creates ``root`` itself.
    for digit in range(10):
        os.makedirs(os.path.join(root, str(digit)), exist_ok=True)

    # zip stops at the shorter of the two, so a count mismatch between the
    # files writes the common prefix instead of failing part-way through.
    for index, (pixels, label) in enumerate(zip(images, labels)):
        file_name = os.path.join(
            root, str(label), str(index).zfill(5) + '.png')
        Image.fromarray(pixels).save(file_name)
if __name__ == "__main__":
    # Run the conversion with its default arguments when executed directly.
    make_img()
model_svm.py:
import numpy as np
import os
from PIL import Image
from sklearn.svm import SVC
import joblib
from sklearn.metrics import confusion_matrix, classification_report
import glob
import time
class DataLoader(object):
    """Pre-training data preparation: image loading and dataset caching."""

    def get_files(self, fpath, fmt = "*.png"):
        """Return the files directly inside a folder that match a pattern.

        paras:
            fpath: str, folder to search,
            fmt: str, glob pattern for the file names,
        return: list of matching paths;
        """
        return glob.glob(os.path.join(fpath, fmt))

    def get_data_labels(self, fpath = "train"):
        """Load every image under ``fpath/<label>/`` with its integer label.

        Each sub-folder name is converted with ``int`` and used as the label
        of every PNG inside it. Entries of ``fpath`` that are not folders are
        skipped.

        paras:
            fpath: str, dataset root containing one folder per label,
        return: (X, labels); X has one flattened image per row, labels is a
            1-D integer array in the same order. Both are empty when no
            label folder is found.
        """
        X = []
        y = []
        for class_dir in glob.glob(os.path.join(fpath, "*")):
            if not os.path.isdir(class_dir):
                # A stray file next to the label folders is not a class.
                continue
            fs = self.get_files(class_dir)
            X.extend(self.img2vec(fn) for fn in fs)
            y.append(np.repeat(int(os.path.basename(class_dir)), len(fs)))
        if not y:
            return np.array(X), np.empty(0, dtype=int)
        # A single concatenate instead of re-copying the array per folder.
        return np.array(X), np.concatenate(y)

    def img2vec(self, fn):
        """Convert an image file into a flat greyscale vector.

        The image is converted to mode 'L', resized to 28x28 and flattened.

        paras:
            fn: str, path of the image file,
        return: 1-D numpy array;
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been copied out by convert().
        with Image.open(fn) as im:
            grey = im.convert('L').resize((28, 28))
        return np.array(grey).ravel()

    def save_data(self, X_data, y_data, fn = "mnist_train_data"):
        """Save the arrays to a compressed archive under keys "X" and "y"."""
        np.savez_compressed(fn, X = X_data, y = y_data)

    def load_data(self, fn = "mnist_train_data.npz"):
        """Load the "X" and "y" arrays back from an archive on disk.

        paras:
            fn: str, path of the .npz file,
        return: (X_data, y_data);
        """
        # np.load keeps the archive open until closed; read both arrays and
        # release the file before returning.
        with np.load(fn) as data:
            X_data = data["X"]
            y_data = data["y"]
        return X_data, y_data
class Trainer(object):
    """Builds the classifier and handles saving/loading it with joblib."""

    def svc(self, x_train, y_train):
        """Fit a degree-4 polynomial-kernel SVC and return the fitted model."""
        classifier = SVC(kernel = 'poly', degree = 4, probability = True)
        classifier.fit(x_train, y_train)
        return classifier

    def save_model(self, model, output_name):
        """Write ``model`` to the file ``output_name`` (compression level 1)."""
        joblib.dump(model, output_name, compress = 1)

    def load_model(self, model_path):
        """Load and return the model stored at ``model_path``."""
        # NOTE: joblib.load unpickles the file, so only load model files
        # that come from a trusted source.
        return joblib.load(model_path)
class Tester(object):
    """Evaluates a saved classifier and predicts single images with it."""

    def __init__(self, model_path):
        """Load the classifier stored at ``model_path`` into ``self.clf``."""
        self.clf = Trainer().load_model(model_path)

    def clf_metrics(self, X_test, y_test):
        """Evaluate the classifier on a labelled test set.

        paras:
            X_test: samples to classify,
            y_test: true labels for X_test,
        return: (confusion matrix, accuracy, classification report text);
        """
        pred = self.clf.predict(X_test)
        cnf_matrix = confusion_matrix(y_test, pred)
        # Accuracy is computed from the predictions already made; calling
        # self.clf.score would classify the whole test set a second time.
        # Assumes a classifier whose score() is mean accuracy (true for
        # scikit-learn classifiers such as the SVC this file trains).
        score = float(np.mean(np.asarray(pred) == np.asarray(y_test)))
        clf_repo = classification_report(y_test, pred)
        return cnf_matrix, score, clf_repo

    def predict(self, fn):
        """Predict the label of the image file ``fn``.

        return: the classifier's prediction for the single sample (an
            array-like with one entry);
        """
        vec = DataLoader().img2vec(fn)
        # The classifier expects a 2-D batch, so present the vector as one row.
        return self.clf.predict(vec.reshape(1, -1))
def run_train():
    """Train on ./train, save the model, then evaluate it on ./test.

    Prints the data-loading time and the training time in seconds, the
    absolute path of the saved model file, and the evaluation results
    (confusion matrix, accuracy, classification report).

    Nothing in the visible source calls this function, so the model file is
    only written once run_train() is invoked explicitly.

    return: (clf, X, y), the fitted classifier and the training data;
    """
    model_file = "mnist_svm.m"

    t0 = time.time()
    loader = DataLoader()
    trainer = Trainer()
    X, y = loader.get_data_labels()
    t1 = time.time()
    print(t1 - t0)

    clf = trainer.svc(X, y)
    print(time.time() - t1)

    trainer.save_model(clf, model_file)
    # The file lands in the current working directory; show where that is.
    print("model saved to", os.path.abspath(model_file))

    X_test, y_test = loader.get_data_labels("test")
    tester = Tester(model_file)
    mt, score, repo = tester.clf_metrics(X_test, y_test)
    # Report the evaluation instead of silently discarding it.
    print(mt)
    print(score)
    print(repo)
    return clf, X, y
GUI界面脚本(digit_gui.py):
import wx
from collections import namedtuple
from PIL import Image
import os
from model_svm import Tester
# Working directory at import time. The file dialog in MainWindow is opened
# with wx.FD_CHANGE_DIR, so the model path is built from this saved value
# rather than from a later os.getcwd().
origin_path = os.getcwd()

# File-dialog filter: "description|pattern" pairs joined by "|".
# The TIFF entry matches both extensions so its label and pattern agree.
wildcard = (
    "png (*.png)|*.png|"
    "jpg(*.jpg) |*.jpg|"
    "jpeg(*.jpeg) |*.jpeg|"
    "tiff(*.tif;*.tiff) |*.tif;*.tiff|"
    "All files (*.*)|*.*"
)
class MainWindow(wx.Frame):
    """Main window: pick an image file and show the digit predicted for it."""

    def __init__(self, parent, title):
        """Build the widgets, load the model and show the frame.

        The model is loaded from 'mnist_svm.m' in the directory that was
        current when this module was imported.
        """
        wx.Frame.__init__(self, parent, title=title, size=(600, -1))
        static_font = wx.Font(12, wx.FONTFAMILY_SWISS,
                              wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL)
        # Base widget size in pixels; every control is a multiple of it.
        unit_w, unit_h = 100, 50

        self.fileName = None
        self.model = Tester(os.path.join(origin_path, 'mnist_svm.m'))

        # (label, tooltip, click handler) for each button, in display order.
        button_specs = [
            (u'open', u'选择图片', self.choose_file),
            (u'run', u'识别数字', self.run),
        ]

        # Input area (chosen file path) and output area (prediction).
        self.in1 = wx.TextCtrl(self, -1, size=(2 * unit_w, 3 * unit_h))
        self.out1 = wx.TextCtrl(self, -1, size=(unit_w, 3 * unit_h))

        # One row: input box, the buttons, output box.
        self.sizer0 = wx.FlexGridSizer(cols=4, hgap=4, vgap=2)
        self.sizer0.Add(self.in1)
        for label, tip, handler in button_specs:
            # Handlers are bound on the button itself, so no fixed id is
            # needed; the width is cast to int because sizes are integers.
            button = wx.Button(self, id=wx.ID_ANY, label=label,
                               size=(int(1.5 * unit_w), unit_h))
            button.SetForegroundColour('red')
            button.SetFont(static_font)
            button.SetToolTip(tip)  # wx4.0 (wx2.8 used SetToolTipString)
            button.Bind(wx.EVT_BUTTON, handler)
            self.sizer0.Add(button)
        self.sizer0.Add(self.out1)

        # Layout.
        self.SetSizer(self.sizer0)
        self.SetAutoLayout(1)
        self.sizer0.Fit(self)
        self.CreateStatusBar()
        self.Show(True)

    def run(self, evt):
        """Predict the digit in the chosen image and show it in the output box."""
        if self.fileName is None:
            self.raise_msg(u'请选择一幅图片')
            return None
        ans = self.model.predict(self.fileName)
        self.out1.Clear()
        self.out1.write(str(ans))

    def choose_file(self, evt):
        """Let the user pick an image; remember and preview the first choice."""
        dlg = wx.FileDialog(
            self, message="Choose a file",
            defaultDir=os.getcwd(),
            defaultFile="",
            wildcard=wildcard,
            # wx4.0 flag names (wx2.8: wx.OPEN | wx.MULTIPLE | wx.CHANGE_DIR)
            style=(wx.FD_OPEN | wx.FD_MULTIPLE |
                   wx.FD_CHANGE_DIR | wx.FD_FILE_MUST_EXIST |
                   wx.FD_PREVIEW)
        )
        # Destroy the dialog on every path, including when it is cancelled.
        try:
            if dlg.ShowModal() != wx.ID_OK:
                return None
            paths = dlg.GetPaths()
        finally:
            dlg.Destroy()

        self.in1.Clear()
        self.in1.write(paths[0])
        self.fileName = paths[0]
        # Close the file once the preview has been handed to the viewer.
        with Image.open(self.fileName) as im:
            im.show()

    def raise_msg(self, msg):
        """Show ``msg`` in a modal warning box titled "Warning Message"."""
        # wx.AboutDialogInfo / wx.AboutBox are not in the wx namespace on
        # wxPython 4 (they moved to wx.adv); a message box suits a warning
        # and needs no extra import.
        wx.MessageBox(msg, "Warning Message", wx.OK | wx.ICON_WARNING, self)
if __name__ == '__main__':
    # Start the GUI: create the application object, build the main window
    # (which shows itself in its constructor) and enter the event loop.
    app = wx.App(False)
    frame = MainWindow(parent=None, title='Digit Recognize')
    app.MainLoop()
我是按照 http://www.demodashi.com/demo/16460.html 这篇教程来做的。
运行结果及报错内容:
![运行结果及报错内容截图](https://img-mid.csdnimg.cn/release/static/image/mid/ask/014544551046191.png "#left")