周其乐之父 2021-12-22 15:05 采纳率: 75%
浏览 178
已结题

在 Jupyter 中训练 MNIST 模型后为什么无法保存?

在 Jupyter 中训练 MNIST 模型后为什么无法保存?

img

parse_mnist_data.ipynb:

import numpy as np 
import struct 
from PIL import Image 
import os 

def make_img(fn_img="t10k-images.idx3-ubyte",
             fn_label="t10k-labels.idx1-ubyte",
             size1=7840016, size2=10008, root="test"):
    """Convert a pair of MNIST IDX files into per-digit PNG folders.

    Every image is written to ``root/<label>/<index>.png``, where ``index``
    is the sample's position in the file zero-padded to five digits.

    paras:
        fn_img: str, path of the IDX3 image file.
        fn_label: str, path of the IDX1 label file.
        size1: kept for backward compatibility only. The number of images
            is now read from the file header, so this value is ignored.
        size2: kept for backward compatibility only. The number of labels
            is now read from the file header, so this value is ignored.
        root: str, output directory. It is created, together with the ten
            digit sub-folders ``0``..``9``, if it does not exist.
    raises:
        ValueError: if either file is shorter than its header announces,
            or if the image and label counts differ.
    """
    with open(fn_img, 'rb') as f:
        data_buf = f.read()
    with open(fn_label, 'rb') as f:
        label_buf = f.read()

    # Image header: magic, image count, rows, columns (big-endian uint32).
    img_header = ">IIII"
    _magic_img, num_images, num_rows, num_cols = struct.unpack_from(img_header, data_buf, 0)
    img_offset = struct.calcsize(img_header)
    n_pixels = num_images * num_rows * num_cols
    if len(data_buf) - img_offset < n_pixels:
        raise ValueError("%s is truncated: header announces %d pixels, file holds %d"
                         % (fn_img, n_pixels, len(data_buf) - img_offset))
    # frombuffer views the raw bytes directly. Unpacking them into a tuple
    # of Python ints (the previous approach) is far slower and needs the
    # exact file size up front.
    datas = np.frombuffer(data_buf, dtype=np.uint8, count=n_pixels, offset=img_offset)
    datas = datas.reshape(num_images, 1, num_rows, num_cols)

    # Label header: magic, label count (big-endian uint32).
    label_header = ">II"
    _magic_label, num_labels = struct.unpack_from(label_header, label_buf, 0)
    label_offset = struct.calcsize(label_header)
    if len(label_buf) - label_offset < num_labels:
        raise ValueError("%s is truncated: header announces %d labels, file holds %d"
                         % (fn_label, num_labels, len(label_buf) - label_offset))
    labels = np.frombuffer(label_buf, dtype=np.uint8, count=num_labels,
                           offset=label_offset).astype(np.int64)

    if num_labels != num_images:
        raise ValueError("image/label count mismatch: %d images vs %d labels"
                         % (num_images, num_labels))

    # makedirs(exist_ok=True) also creates missing parents of `root`.
    for digit in range(10):
        os.makedirs(os.path.join(root, str(digit)), exist_ok=True)

    for ii in range(num_labels):
        img = Image.fromarray(datas[ii, 0])
        file_name = os.path.join(root, str(labels[ii]), str(ii).zfill(5) + '.png')
        img.save(file_name)
        
if __name__ == "__main__":
    # These defaults only produce the test split (./test/<digit>/*.png);
    # run_train in model_svm.py also expects a ./train folder.
    make_img(fn_img="t10k-images.idx3-ubyte",
             fn_label="t10k-labels.idx1-ubyte",
             root="test")


model_svm.py:

import numpy as np
import os
from PIL import Image
from sklearn.svm import SVC
import joblib
from sklearn.metrics import confusion_matrix, classification_report
import glob
import time

class DataLoader(object):
    """Preprocessing helpers: read digit images from disk and cache them as .npz."""

    def get_files(self, fpath, fmt = "*.png"):
        """Return the files in directory ``fpath`` matching glob pattern ``fmt``.

        paras:
            fpath: str, directory to search.
            fmt: str, glob pattern, ``"*.png"`` by default.
        return: list of matching paths, in the order ``glob`` yields them.
        """
        return glob.glob(os.path.join(fpath, fmt))

    def get_data_labels(self, fpath = "train"):
        """Load every ``*.png`` under ``fpath/<digit>/`` as a flat vector.

        Each sub-directory of ``fpath`` is one class, and its name must be
        the integer label. Plain files directly inside ``fpath`` are skipped.

        paras:
            fpath: str, root folder containing the class sub-directories.
        return: (X, labels). X is an array with one vectorised image per
            row (see ``img2vec``). labels is an int array of the same length.
            Both are empty when no class directory or image is found.
        """
        X = []
        y = []
        # sorted() gives a deterministic sample order across platforms.
        for class_dir in sorted(glob.glob(os.path.join(fpath, "*"))):
            if not os.path.isdir(class_dir):
                # A stray file (notes, cached .npz, ...) is not a class
                # folder, and int() of its name would crash.
                continue
            fs = self.get_files(class_dir)
            X.extend(self.img2vec(fn) for fn in fs)
            y.append(np.repeat(int(os.path.basename(class_dir)), len(fs)))
        # One concatenate instead of repeated np.append copies. The empty
        # case previously raised IndexError on y[0].
        labels = np.concatenate(y) if y else np.empty(0, dtype=int)
        return np.array(X), labels

    def img2vec(self, fn):
        """Turn an image file (png, jpg, ...) into a flat 28*28 grayscale vector."""
        im = Image.open(fn).convert('L')
        im = im.resize((28,28))
        tmp = np.array(im)
        vec = tmp.ravel()
        return vec

    def save_data(self, X_data, y_data, fn = "mnist_train_data"):
        """Save X/y to a compressed ``.npz`` file (NumPy appends ``.npz`` if missing)."""
        np.savez_compressed(fn, X = X_data, y = y_data)

    def load_data(self, fn = "mnist_train_data.npz"):
        """Load the X/y arrays written by ``save_data``.

        return: (X_data, y_data)
        """
        # np.load on an .npz returns an NpzFile that keeps the file open.
        # The with-block closes it once both arrays have been read.
        with np.load(fn) as data:
            X_data = data["X"]
            y_data = data["y"]
        return X_data, y_data

class Trainer:
    """Build, persist and restore the SVM digit classifier."""

    def svc(self, x_train, y_train):
        """Fit and return a degree-4 polynomial-kernel SVC.

        The model is created with ``probability=True``.
        """
        # SVC.fit returns the fitted estimator itself.
        return SVC(kernel='poly', degree=4, probability=True).fit(x_train, y_train)

    def save_model(self, model, output_name):
        """Write ``model`` to ``output_name`` with joblib (compression level 1)."""
        joblib.dump(model, output_name, compress=1)

    def load_model(self, model_path):
        """Read back a model previously written by ``save_model``."""
        return joblib.load(model_path)

class Tester:
    """Evaluate a saved classifier and use it for single-image prediction."""

    def __init__(self, model_path):
        """Load the classifier stored at ``model_path`` via ``Trainer.load_model``."""
        self.clf = Trainer().load_model(model_path)

    def clf_metrics(self, X_test, y_test):
        """Score the loaded classifier on a labelled test set.

        return: (confusion matrix, value of ``clf.score``, text report
            from ``classification_report``).
        """
        predictions = self.clf.predict(X_test)
        return (
            confusion_matrix(y_test, predictions),
            self.clf.score(X_test, y_test),
            classification_report(y_test, predictions),
        )

    def predict(self, fn):
        """Predict the digit in image file ``fn``.

        return: the array returned by ``clf.predict`` for one sample.
        """
        sample = DataLoader().img2vec(fn).reshape(1, -1)
        return self.clf.predict(sample)

def run_train(train_dir="train", test_dir="test", model_path="mnist_svm.m"):
    """Train the SVM, save it to disk, then reload and evaluate the saved copy.

    A relative ``model_path`` resolves against the current working
    directory. In Jupyter this is typically the notebook's folder, so the
    absolute path is printed to make the saved file easy to find. Reloading
    through ``Tester`` confirms the file on disk is usable.

    paras:
        train_dir: folder with one sub-folder per digit (see
            ``DataLoader.get_data_labels``).
        test_dir: folder with the same layout, used for evaluation.
        model_path: destination file for the saved model.
    return: (clf, X, y), the fitted classifier and the training data.
    raises:
        ValueError: if no training images were found under ``train_dir``.
    """
    t0 = time.time()
    loader = DataLoader()
    trainer = Trainer()

    X, y = loader.get_data_labels(train_dir)
    if len(y) == 0:
        raise ValueError("no training images found under %r" % os.path.abspath(train_dir))
    t1 = time.time()
    print(f"loaded {len(y)} training images in {t1 - t0:.1f}s")
    clf = trainer.svc(X, y)
    print(f"trained in {time.time() - t1:.1f}s")

    trainer.save_model(clf, model_path)
    print("model saved to", os.path.abspath(model_path))

    X_test, y_test = loader.get_data_labels(test_dir)
    tester = Tester(model_path)
    _cnf_matrix, score, repo = tester.clf_metrics(X_test, y_test)
    # These metrics were previously computed and silently discarded.
    print(f"test accuracy: {score:.4f}")
    print(repo)
    return clf, X, y


if __name__ == "__main__":
    # Without this call, running model_svm.py only defined functions, so no
    # model was ever trained or saved.
    run_train()

GUI界面脚本(digit_gui.py):






import wx
from collections import namedtuple
from PIL import Image
import os
from model_svm import Tester


origin_path = os.getcwd()
wildcard ="png (*.png)|*.png|" \
           "jpg(*.jpg) |*.jpg|"\
           "jpeg(*.jpeg) |*.jpeg|"\
           "tiff(*.tif) |*.tiff|"\
           "All files (*.*)|*.*"

class MainWindow(wx.Frame):
    """Main window: pick an image file and show the digit predicted by the SVM."""

    def __init__(self, parent, title):
        """Build the widgets, load ``mnist_svm.m`` from ``origin_path`` and show the frame."""
        wx.Frame.__init__(self, parent, title=title, size=(600, -1))
        static_font = wx.Font(12, wx.FONTFAMILY_SWISS, wx.FONTSTYLE_NORMAL,
                              wx.FONTWEIGHT_NORMAL)

        Size = namedtuple("Size", ['x', 'y'])
        s = Size(100, 50)
        model_path = os.path.join(origin_path, 'mnist_svm.m')

        self.fileName = None
        self.model = Tester(model_path)

        b_labels = [u'open', u'run']
        TipString = [u'选择图片', u'识别数字']
        funcs = [self.choose_file, self.run]

        # Input box (chosen file path) and output box (prediction).
        self.in1 = wx.TextCtrl(self, -1, size=(2 * s.x, 3 * s.y))
        self.out1 = wx.TextCtrl(self, -1, size=(s.x, 3 * s.y))

        # Layout: input box, the buttons, output box in one row.
        self.sizer0 = wx.FlexGridSizer(cols=4, hgap=4, vgap=2)
        self.sizer0.Add(self.in1)

        buttons = []
        for label in b_labels:
            # wx.ID_ANY lets wx allocate ids instead of hand-picked 0/1.
            # Sizes must be ints: newer wxPython/Python combinations reject
            # the float 1.5*s.x with a TypeError.
            b = wx.Button(self, id=wx.ID_ANY, label=label,
                          size=(int(1.5 * s.x), s.y))
            buttons.append(b)
            self.sizer0.Add(b)

        self.sizer0.Add(self.out1)

        # Button colour, font, tooltip and click handler.
        for button, tip, handler in zip(buttons, TipString, funcs):
            button.SetForegroundColour('red')
            button.SetFont(static_font)
            button.SetToolTip(tip)  # wx4 API; wx2.8 used SetToolTipString
            button.Bind(wx.EVT_BUTTON, handler)

        self.SetSizer(self.sizer0)
        self.SetAutoLayout(True)
        self.sizer0.Fit(self)

        self.CreateStatusBar()
        self.Show(True)

    def run(self, evt):
        """Predict the digit in the chosen image and show it in the output box."""
        if self.fileName is None:
            self.raise_msg(u'请选择一幅图片')
            return None
        ans = self.model.predict(self.fileName)
        self.out1.Clear()
        # predict() returns a one-element array; show the digit, not "[7]".
        self.out1.write(str(ans[0]))

    def choose_file(self, evt):
        """Open a file dialog, then remember and preview the first selected image."""
        dlg = wx.FileDialog(
            self, message="Choose a file",
            defaultDir=os.getcwd(),
            defaultFile="",
            wildcard=wildcard,
            style=(wx.FD_OPEN | wx.FD_MULTIPLE | wx.FD_CHANGE_DIR
                   | wx.FD_FILE_MUST_EXIST | wx.FD_PREVIEW),
        )
        try:
            if dlg.ShowModal() != wx.ID_OK:
                return None
            paths = dlg.GetPaths()
        finally:
            # Destroy the dialog on every path; it previously leaked on Cancel.
            dlg.Destroy()
        self.in1.Clear()
        self.in1.write(paths[0])
        self.fileName = paths[0]
        im = Image.open(self.fileName)
        im.show()

    def raise_msg(self, msg):
        """Show ``msg`` in a modal warning dialog."""
        # wx.AboutDialogInfo / wx.AboutBox moved to wx.adv in wxPython 4, so
        # the old call raised AttributeError. A plain message box suffices.
        wx.MessageBox(msg, "Warning Message", wx.OK | wx.ICON_WARNING, self)
        
if __name__ == '__main__':
    # Create the wx application, open the main window and run the event loop.
    application = wx.App(redirect=False)
    main_window = MainWindow(None, 'Digit Recognize')
    application.MainLoop()

我是按照 http://www.demodashi.com/demo/16460.html 中的教程操作的。

运行结果:![运行结果](https://img-mid.csdnimg.cn/release/static/image/mid/ask/014544551046191.png "#left")

报错内容

img

我想要达到的结果:保存训练好的 MNIST 模型
  • 写回答

3条回答 默认 最新

  • 江小皮不皮 2021-12-22 16:26
    关注

    看了。挺简单的。

    img

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 12月30日
  • 已采纳回答 12月22日
  • 赞助了问题酬金 12月22日
  • 创建了问题 12月22日

悬赏问题

  • ¥30 VMware 云桌面水印如何添加
  • ¥15 用ns3仿真出5G核心网网元
  • ¥15 matlab答疑 关于海上风电的爬坡事件检测
  • ¥88 python部署量化回测异常问题
  • ¥30 酬劳2w元求合作写文章
  • ¥15 在现有系统基础上增加功能
  • ¥15 远程桌面文档内容复制粘贴,格式会变化
  • ¥15 这种微信登录授权 谁可以做啊
  • ¥15 请问我该如何添加自己的数据去运行蚁群算法代码
  • ¥20 用HslCommunication 连接欧姆龙 plc有时会连接失败。报异常为“未知错误”