vgg16 fine-tune keras 20C

哪位可以分享一下keras vgg16 fine-tune程序代码 程序能够正常运行,qq:1246365615

1个回答

这是jupyter notebook文件:

 {
  "cells": [
    {
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "from sklearn.model_selection import train_test_split\n",
        "from keras.applications.vgg16 import VGG16\n",
        "from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau\n",
        "from keras.layers import Input, Dense, Dropout, Flatten\n",
        "from keras.layers import Conv2D, MaxPooling2D\n",
        "from keras.models import Sequential, Model\n",
        "from keras.optimizers import Adam\n",
        "from keras.preprocessing.image import ImageDataGenerator\n",
        "np.random.seed(7)"
      ],
      "outputs": [],
      "metadata": {
        "_cell_guid": "ad5a3ddc-02c5-4699-9e83-c58a09b9af25",
        "_uuid": "fbcb0242449f7516052a04145a2119656907ea87"
      },
      "cell_type": "code",
      "execution_count": 1
    },
    {
      "source": [
        "def make_df(path, mode):\n",
        "    \"\"\"\n",
        "    params\n",
        "    --------\n",
        "    path(str): path to json\n",
        "    mode(str): \"train\" or \"test\"\n",
        "\n",
        "    outputs\n",
        "    --------\n",
        "    X(np.array): list of images shape=(None, 75, 75, 3)\n",
        "    Y(np.array): list of labels shape=(None,)\n",
        "    df(pd.DataFrame): data frame from json\n",
        "    \"\"\"\n",
        "    df = pd.read_json(path)\n",
        "    df.inc_angle = df.inc_angle.replace('na', 0)\n",
        "    X = _get_scaled_imgs(df)\n",
        "    if mode == \"test\":\n",
        "        return X, df\n",
        "\n",
        "    Y = np.array(df['is_iceberg'])\n",
        "\n",
        "    idx_tr = np.where(df.inc_angle > 0)\n",
        "\n",
        "    X = X[idx_tr[0]]\n",
        "    Y = Y[idx_tr[0], ...]\n",
        "\n",
        "    return X, Y\n",
        "\n",
        "\n",
        "def _get_scaled_imgs(df):\n",
        "    imgs = []\n",
        "\n",
        "    for i, row in df.iterrows():\n",
        "        band_1 = np.array(row['band_1']).reshape(75, 75)\n",
        "        band_2 = np.array(row['band_2']).reshape(75, 75)\n",
        "        band_3 = band_1 + band_2\n",
        "\n",
        "        a = (band_1 - band_1.mean()) / (band_1.max() - band_1.min())\n",
        "        b = (band_2 - band_2.mean()) / (band_2.max() - band_2.min())\n",
        "        c = (band_3 - band_3.mean()) / (band_3.max() - band_3.min())\n",
        "\n",
        "        imgs.append(np.dstack((a, b, c)))\n",
        "\n",
        "    return np.array(imgs)"
      ],
      "outputs": [],
      "metadata": {
        "collapsed": true
      },
      "cell_type": "code",
      "execution_count": 2
    },
    {
      "source": [
        "def SmallCNN():\n",
        "    model = Sequential()\n",
        "\n",
        "    model.add(Conv2D(64, kernel_size=(3, 3), activation='relu',\n",
        "                     input_shape=(75, 75, 3)))\n",
        "    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))\n",
        "    model.add(Dropout(0.2))\n",
        "\n",
        "    model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))\n",
        "    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
        "    model.add(Dropout(0.2))\n",
        "\n",
        "    model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))\n",
        "    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
        "    model.add(Dropout(0.3))\n",
        "\n",
        "    model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))\n",
        "    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n",
        "    model.add(Dropout(0.3))\n",
        "\n",
        "    model.add(Flatten())\n",
        "    model.add(Dense(512, activation='relu'))\n",
        "    model.add(Dropout(0.2))\n",
        "\n",
        "    model.add(Dense(256, activation='relu'))\n",
        "    model.add(Dropout(0.2))\n",
        "\n",
        "    model.add(Dense(1, activation=\"sigmoid\"))\n",
        "\n",
        "    return model"
      ],
      "outputs": [],
      "metadata": {
        "collapsed": true
      },
      "cell_type": "code",
      "execution_count": 3
    },
    {
      "source": [
        "def Vgg16():\n",
        "    input_tensor = Input(shape=(75, 75, 3))\n",
        "    vgg16 = VGG16(include_top=False, weights='imagenet',\n",
        "                  input_tensor=input_tensor)\n",
        "\n",
        "    top_model = Sequential()\n",
        "    top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))\n",
        "    top_model.add(Dense(512, activation='relu'))\n",
        "    top_model.add(Dropout(0.5))\n",
        "    top_model.add(Dense(256, activation='relu'))\n",
        "    top_model.add(Dropout(0.5))\n",
        "    top_model.add(Dense(1, activation='sigmoid'))\n",
        "\n",
        "    model = Model(input=vgg16.input, output=top_model(vgg16.output))\n",
        "    for layer in model.layers[:13]:\n",
        "        layer.trainable = False\n",
        "\n",
        "    return model"
      ],
      "outputs": [],
      "metadata": {
        "collapsed": true
      },
      "cell_type": "code",
      "execution_count": 4
    },
    {
      "source": [
        "if __name__ == \"__main__\":\n",
        "    x, y = make_df(\"../input/train.json\", \"train\")\n",
        "    xtr, xval, ytr, yval = train_test_split(x, y, test_size=0.25,\n",
        "                                            random_state=7)\n",
        "    model = SmallCNN()\n",
        "    #model = Vgg16()\n",
        "    optimizer = Adam(lr=0.001, decay=0.0)\n",
        "    model.compile(loss='binary_crossentropy', optimizer=optimizer,\n",
        "                  metrics=['accuracy'])\n",
        "\n",
        "    earlyStopping = EarlyStopping(monitor='val_loss', patience=20, verbose=0,\n",
        "                                  mode='min')\n",
        "    ckpt = ModelCheckpoint('.model.hdf5', save_best_only=True,\n",
        "                           monitor='val_loss', mode='min')\n",
        "    reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1,\n",
        "                                       patience=7, verbose=1, epsilon=1e-4,\n",
        "                                       mode='min')\n",
        "\n",
        "    gen = ImageDataGenerator(horizontal_flip=True,\n",
        "                             vertical_flip=True,\n",
        "                             width_shift_range=0,\n",
        "                             height_shift_range=0,\n",
        "                             channel_shift_range=0,\n",
        "                             zoom_range=0.2,\n",
        "                             rotation_range=10)\n",
        "    gen.fit(xtr)\n",
        "    model.fit_generator(gen.flow(xtr, ytr, batch_size=32),\n",
        "                        steps_per_epoch=len(xtr), epochs=1,\n",
        "                        callbacks=[earlyStopping, ckpt, reduce_lr_loss],\n",
        "                        validation_data=(xval, yval))\n",
        "\n",
        "    model.load_weights(filepath='.model.hdf5')\n",
        "    score = model.evaluate(xtr, ytr, verbose=1)\n",
        "    print('Train score:', score[0], 'Train accuracy:', score[1])\n",
        "\n",
        "    xtest, df_test = make_df(\"../input/test.json\", \"test\")\n",
        "    pred_test = model.predict(xtest)\n",
        "    pred_test = pred_test.reshape((pred_test.shape[0]))\n",
        "    submission = pd.DataFrame({'id': df_test[\"id\"], 'is_iceberg': pred_test})\n",
        "    submission.to_csv('submission.csv', index=False)"
      ],
      "outputs": [],
      "metadata": {},
      "cell_type": "code",
      "execution_count": 5
    }
  ],
  "nbformat": 4,
  "nbformat_minor": 1,
  "metadata": {
    "language_info": {
      "name": "python",
      "version": "3.6.3",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "file_extension": ".py",
      "mimetype": "text/x-python"
    },
    "kernelspec": {
      "name": "python3",
      "language": "python",
      "display_name": "Python 3"
    }
  }
}

来自:https://www.kaggle.com/takuok/keras-smallcnn-and-vgg16-fine-tuning/code

另外的资料:http://marubon-ds.blogspot.com/2017/09/vgg16-fine-tuning-model.html (需要科学上网)
https://flyyufelix.github.io/2016/10/03/fine-tuning-in-keras-part1.html

Csdn user default icon
上传中...
上传图片
插入图片
抄袭、复制答案,以达到刷声望分或其他目的的行为,在CSDN问答是严格禁止的,一经发现立刻封号。是时候展现真正的技术了!
其他相关推荐
keras yolov3 tiny_yolo_body网络结构改为vgg16结构
-
VGG16和ResNet50的mAP问题
-
为什么用vgg16网络训练我自己的数据集,loss一直在1.7左右震荡,用训练好的模型进行预测出来的都是一个值?
-
如何将pytorch的VGG16改为CNN+ELM?
-
请大神指点,VGG-16训练时权重不更新,怎么回事??
-
vgg16计算反向传播时,无法读取全连接层
-
导入npy预训练文件出现No gradients provided for any variable
-
使用反卷积tf.nn.conv2d_transpose函数,算出来为什么都是(?,?,?,2)的形式?
-
已有原图像和mask 怎么去制作数据集呢
-
Keras报错 ‘ValueError: 'pool5' is not in list’
-
tensorflow训练过程权重不更新,loss不下降,输出保持不变,只有bias在非常缓慢地变化?
-
迁移学习中进行医学影像分析,训练神经网络后accuracy保持不变。。。
-
训练网络时损失值一直震荡
-
度量学习中三元组损失不收敛(loss无法下降到margin以下,样本的降维输出聚在一起)
-
利用caffe训练VGG网络出现错误
-
linux下使用pytorch框架出现cuda run out of memory问题
-
深度学习VGG模型加载硬件条件
-
TensorFlow SSD训练自己的数据 checkpoint问题
-
关于深度学习图片数据集的建立
-
程序员实用工具网站
目录 1、搜索引擎 2、PPT 3、图片操作 4、文件共享 5、应届生招聘 6、程序员面试题库 7、办公、开发软件 8、高清图片、视频素材网站 9、项目开源 10、在线工具宝典大全 程序员开发需要具备良好的信息检索能力,为了备忘(收藏夹真是满了),将开发过程中常用的网站进行整理。 1、搜索引擎 1.1、秘迹搜索 一款无敌有良心、无敌安全的搜索引擎,不会收集私人信息,保...
我花了一夜用数据结构给女朋友写个H5走迷宫游戏
起因 又到深夜了,我按照以往在csdn和公众号写着数据结构!这占用了我大量的时间!我的超越妹妹严重缺乏陪伴而 怨气满满! 而女朋友时常埋怨,认为数据结构这么抽象难懂的东西没啥作用,常会问道:天天写这玩意,有啥作用。而我答道:能干事情多了,比如写个迷宫小游戏啥的! 当我码完字准备睡觉时:写不好别睡觉! 分析 如果用数据结构与算法造出东西来呢? ...
别再翻了,面试二叉树看这 11 个就够了~
写在前边 数据结构与算法: 不知道你有没有这种困惑,虽然刷了很多算法题,当我去面试的时候,面试官让你手写一个算法,可能你对此算法很熟悉,知道实现思路,但是总是不知道该在什么地方写,而且很多边界条件想不全面,一紧张,代码写的乱七八糟。如果遇到没有做过的算法题,思路也不知道从何寻找。面试吃了亏之后,我就慢慢的做出总结,开始分类的把数据结构所有的题型和解题思路每周刷题做出的系统性总结写在了 Github...
让程序员崩溃的瞬间(非程序员勿入)
今天给大家带来点快乐,程序员才能看懂。 来源:https://zhuanlan.zhihu.com/p/47066521 1. 公司实习生找 Bug 2.在调试时,将断点设置在错误的位置 3.当我有一个很棒的调试想法时 4.偶然间看到自己多年前写的代码 5.当我第一次启动我的单元测试时 ...
接私活必备的 10 个开源项目!
点击蓝色“GitHubDaily”关注我加个“星标”,每天下午 18:35,带你逛 GitHub!作者 | SevDot来源 | http://1t.click/VE8W...
GitHub开源的10个超棒后台管理面板
目录 1、AdminLTE 2、vue-Element-Admin 3、tabler 4、Gentelella 5、ng2-admin 6、ant-design-pro 7、blur-admin 8、iview-admin 9、material-dashboard 10、layui 项目开发中后台管理平台必不可少,但是从零搭建一套多样化后台管理并不容易,目前有许多开源、免费、...
100 个网络基础知识普及,看完成半个网络高手
欢迎添加华为云小助手微信(微信号:HWCloud002或HWCloud003),输入关键字“加群”,加入华为云线上技术讨论群;输入关键字“最新活动”,获取华为云最新特惠促销。华为云诸多技术大咖、特惠活动等你来撩! 1)什么是链接? 链接是指两个设备之间的连接。它包括用于一个设备能够与另一个设备通信的电缆类型和协议。 2)OSI 参考模型的层次是什么? 有 7 个 OSI 层:物理...
VS CODE远程开发入门
在我们办公室,通常配置两台电脑,一台 Windows 主机,主要用于办公、即时通讯,一台 Linux 主机,用于开发。一般开发人员习惯用 Windows 系统下的工具,比如 Source Insight ,但代码需要在 Linux 下编译。这样就需要 Windows 和 Linux 之间协作,通常的做法是在 Linux 下安装 samba 服务,通过 Windows 共享访问。今天看到一篇文章,...
中国最顶级的一批程序员,从首富到首负!
过去的20年是程序员快意恩仇的江湖时代通过代码,实现梦想和财富有人痴迷于技术,做出一夜成名的产品有人将技术变现,创办企业成功上市这些早一代的程序员们创造的奇迹引发了一浪高...
为什么面向对象糟透了?
又是周末,编程语言“三巨头”Java, Lisp 和C语言在Hello World咖啡馆聚会。服务员送来咖啡的同时还带来了一张今天的报纸, 三人寒暄了几句, C语言翻开了...
分享靠写代码赚钱的一些门路
作者 mezod,译者 josephchang10如今,通过自己的代码去赚钱变得越来越简单,不过对很多人来说依然还是很难,因为他们不知道有哪些门路。今天给大家分享一个精彩...
对计算机专业来说学历真的重要吗?
我本科学校是渣渣二本,研究生学校是985,现在毕业五年,校招笔试、面试,社招面试参加了两年了,就我个人的经历来说下这个问题。 这篇文章很长,但绝对是精华,相信我,读完以后,你会知道学历不好的解决方案,记得帮我点赞哦。 先说结论,无论赞不赞同,它本质就是这样:对于技术类工作而言,学历五年以内非常重要,但有办法弥补。五年以后,不重要。 目录: 张雪峰讲述的事实 我看到的事实 为什么会这样 ...
世界上最好的学习法:费曼学习法
你是否曾幻想读一遍书就记住所有的内容?是否想学习完一项技能就马上达到巅峰水平?除非你是天才,不然这是不可能的。对于大多数的普通人来说,可以通过笨办法(死记硬背)来达到学习的目的,但效率低下。当然,也可以通过优秀的学习法来进行学习,比如今天讲的“费曼学习法”,可以将你的学习效率极大的提高。 费曼学习法是由加拿大物理学家费曼所发明的一种高效的学习方法,费曼本身是一个天才,13岁自学微积分,24岁加入曼...
学Linux到底学什么
来源:公众号【编程珠玑】 作者:守望先生 网站:https://www.yanbinghu.com/2019/09/25/14472.html 前言 ​我们常常听到很多人说要学学Linux或者被人告知说应该学学Linux,那么学Linux到底要学什么? 为什么要学Linux 在回答学什么之前,我们先看看为什么要学。首先我们需要认识到的是,很多服务器使用的是Linux系统,而作为服务器应...
深入理解C语言指针
一、指针的概念 要知道指针的概念,要先了解变量在内存中如何存储的。在存储时,内存被分为一块一块的。每一块都有一个特有的编号。而这个编号可以暂时理解为指针,就像酒店的门牌号一样。 1.1、变量和地址 先写一段简单的代码: void main(){ int x = 10, int y = 20; } 这段代码非常简单,就是两个变量的声明,分别赋值了 10、20。我们把内存当做一个酒店,而每个房间就...
C语言实现推箱子游戏
很早就想过做点小游戏了,但是一直没有机会动手。今天闲来无事,动起手来。过程还是蛮顺利的,代码也不是非常难。今天给大家分享一下~ 一、介绍 开发语言:C语言 开发工具:Dev-C++ 5.11 日期:2019年9月28日 作者:ZackSock 也不说太多多余的话了,先看一下效果图: 游戏中的人物、箱子、墙壁、球都是字符构成的。通过wasd键移动,规则的话就是推箱子的规则,也就不多说了。 二、代...
面试官:兄弟,说说基本类型和包装类型的区别吧
Java 的每个基本类型都对应了一个包装类型,比如说 int 的包装类型为 Integer,double 的包装类型为 Double。基本类型和包装类型的区别主要有以下 4 点。
8000字干货:那些很厉害的人是怎么构建知识体系的
本文约8000字,正常阅读需要15~20分钟。读完本文可以获得如下收益: 分辨知识和知识体系的差别 理解如何用八大问发现知识的连接点; 掌握致用类知识体系的构建方法; 能够应用甜蜜区模型找到特定领域来构建知识体系。 1. 知识体系?有必要吗? 小张准备通过跑步锻炼身体,可因为之前听说过小腿变粗、膝盖受伤、猝死等等与跑步有关的意外状况,有点担心自己会掉进各种坑里,就在微信上问朋友圈一直晒跑步...
Android完整知识体系路线(菜鸟-资深-大牛必进之路)
前言 移动研发火热不停,越来越多人开始学习Android 开发。但很多人感觉入门容易成长很难,对未来比较迷茫,不知道自己技能该怎么提升,到达下一阶段需要补充哪些内容。市面上也多是谈论知识图谱,缺少体系和成长节奏感,特此编写一份 Android 研发进阶之路,希望能对大家有所帮助。 由于篇幅过长,有些问题的答案并未放在文章当中,不过我都整理成了一个文档归纳好了,请阅读到文末领取~ Ja...
网易云音乐你喜欢吗?你自己也可以做一个
【公众号回复 “1024”,免费领取程序员赚钱实操经验】今天我章鱼猫给大家带来的这个开源项目,估计很多喜欢听音乐的朋友都会喜欢。就目前来讲,很多人对这款音乐 App 都抱...
C语言这么厉害,它自身又是用什么语言写的?
这是来自我的星球的一个提问:“C语言本身用什么语言写的?”换个角度来问,其实是:C语言在运行之前,得编译才行,那C语言的编译器从哪里来? 用什么语言来写的?如果是用C语...
相关热词 c#部门请假管理系统 c#服务器socket c# 默认的访问修饰符 c#拖动文件 c# 截取指定窗口屏幕 c# html对象传后台 c# 判断域名还是ip c#遮罩层 c# 取字符串中的数字 c# 网站高并发测试