Python,类中字典变量的疑问

各位python大神,我刚开始看python,遇到一个问题,有点疑惑,请高手给解下惑
我在一个类中定义了一个字符串变量和一个字典变量

我实例化几个实例,分别对字符串变量和字典变量进行赋值,然后存储到一个链表中,最后打印发现字典变量总是会被最后一次赋值覆盖掉。

字典变量在类中就相当于static了吗?

SysConfig 数据库中的信息读取后,保存到这个SysConfig中

 # coding=utf-8

import json

class SysConfig(object):
    syscfg = {
        "id":0,
        "name":"",
        "switcherid":0,
        "centreconsoleid":0,
        "port":"",
        "teacher":"",
        "teacher_panorama":"",
        "student":"",
        "student_panorama":"",
        "blackboard":"",
        "course":""
    }

    test="test";

    def __init__(self):
        pass;

    def getSysCfg(self):
        return self.syscfg;

    def setSyscfg(self, cfg):
        for key in cfg:
            self.syscfg[key] = cfg[key];

        self.test = str(cfg);

数据库操作类:

 import pymysql
import sys
import sysCfg

class GSDBHelper(object):

    dbConfig = {
          'host':'127.0.0.1',
          'port':3306,
          'user':'root',
          'password':'root',
          'db':'db_test',
          'charset':'utf8mb4',
          'cursorclass':pymysql.cursors.DictCursor
          }

    dbsysList = [];

    def __init__(self):
        pass;

    def setConfig(self, host, port, user, password, db):
        self.dbConfig["host"] = host;
        self.dbConfig["port"] = port;
        self.dbConfig["user"] = user;
        self.dbConfig["password"] = password;
        self.dbConfig["db"] = db;

        # print("db config...", self.dbConfig);

    def getSysConfig(self):
        try:
            self.dbsysList.clear();

            conn = pymysql.connect(**self.dbConfig);
            cur = conn.cursor();
            cur.execute("select * from dbt_sys");
            rows = cur.fetchall();
            for row in rows:
                dbsys = sysCfg.SysConfig();
                dbsys.setSyscfg(row);
                self.dbsysList.append(dbsys);

            for sysItem in self.dbsysList:
                print(sysItem, " syscfg... ", sysItem.getSysCfg());
                print(sysItem, " test 字符串  ", sysItem.test);

        except:
            info = sys.exc_info();
            print(info[0], ":", info[1]);
        finally:
            if "conn" in locals().keys():
                conn.close();

执行:

 import dbhelper

if __name__ == '__main__':
    db = dbhelper.GSDBHelper();
    db.setConfig("127.0.0.1", 3306, "root", "ab123@", "db_gsaid");
    db.getSysConfig();

运行结果:

 <sysCfg.SysConfig object at 0x0293A2D0>  syscfg...  {'id': 2, 'name': None, 'switcherid': 1, 'port': '{"sprite":99991, "cameractrl1":2001, "cameractrl2":2002, "cameractrl3":2003, "cameractrl4":2004}', 'centreconsoleid': None, 'teacher': None, 'teacher_panorama': None, 'student': None, 'student_panorama': None, 'blackboard': None, 'course': None, 'info': None}
<sysCfg.SysConfig object at 0x0293A2D0>  test 字符串   {'id': 1, 'name': '', 'switcherid': 1, 'centreconsoleid': 1, 'port': '{"sprite":9999, "cameractrl1":2001, "cameractrl2":20021, "cameractrl3":2003, "cameractrl4":2004}', 'teacher': '{"time":{"min":2000, "max":600000}, "lostormax_cutto":"student_panorama", "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0}}', 'teacher_panorama': '{"time":{"min":2000, "max":100000}, "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0} }', 'student': '{"time":{"min":2000, "max":100000} , "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0}}', 'student_panorama': '{"time":{"min":2000, "max":100000} , "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0}}', 'blackboard': '{"time":{"min":2000, "max":100000}, "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0} }', 'course': '{"time":{"min":2000, "max":100000} , "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0}}', 'info': None}
<sysCfg.SysConfig object at 0x0293A2F0>  syscfg...  {'id': 2, 'name': None, 'switcherid': 1, 'port': '{"sprite":99991, "cameractrl1":2001, "cameractrl2":2002, "cameractrl3":2003, "cameractrl4":2004}', 'centreconsoleid': None, 'teacher': None, 'teacher_panorama': None, 'student': None, 'student_panorama': None, 'blackboard': None, 'course': None, 'info': None}
<sysCfg.SysConfig object at 0x0293A2F0>  test 字符串   {'id': 2, 'name': None, 'switcherid': 1, 'centreconsoleid': None, 'port': '{"sprite":99991, "cameractrl1":2001, "cameractrl2":2002, "cameractrl3":2003, "cameractrl4":2004}', 'teacher': None, 'teacher_panorama': None, 'student': None, 'student_panorama': None, 'blackboard': None, 'course': None, 'info': None}

1个回答

原来这样。。。

要理解一下python跟别的面向对象语言的不同,应该这样:

 class people:

    def __init__(self, s, n):
       self. m = {}
       self.m[s] = n
    def show(self):
       self. m = {}
       print(self.m)

p1 = people('年龄', 1)
p2 = people('年龄', 2)

p1.show()
p2.show()
Csdn user default icon
上传中...
上传图片
插入图片
抄袭、复制答案,以达到刷声望分或其他目的的行为,在CSDN问答是严格禁止的,一经发现立刻封号。是时候展现真正的技术了!
其他相关推荐
python字典中键值使用变量?
``` dict={'1':["TEST1"], '2':["TEST2"], '3':["TEST3"], } i=1 x=dict["%d"] %d i print(x) ``` 我想使用一个变量去引用键值,结果发现这样操作不可行,请问下有大神知道该怎么解决么
python构建get请求如何把字典参数作为一个变量传进去
![图片说明](https://img-ask.csdn.net/upload/201912/05/1575555981_756412.png)
python字典的值为字母,怎样变成整数使其一一对应关系
多张excel表格,要提取excel表格的某一列做数据分析,表格不是固定的格式,想通过变量来提取,excel的列是以字母排序的 比如: import pandas as pd dic={"长":D,"宽":G} #长的数据在D列,宽的数据在G列 Length=dic.get("长") Width=dic.get("宽") df=pd.read_excel("d:/ex.xlsx") df1=df.iloc[:,[Length,Width]] ...... _问题是怎样把 Length=dic.get("长") Width=dic.get("宽") 变成 Length=3 Width=6 即把字典的字母值变成数字值,A-Z和0-25一一对应关系, 值为A时变成值为0,值为B时变成值为1,依此类推!谢谢大家!_
python+mysql:在插入语句中使用变量,变量中含有多个引号,和双引号,该如何使用sql语句
数据是一个字典类型: my_dict={"2017-07" : [ {"origin" : "LJ","price" : 44267,"crawl_date" : "2017-09-01"}]} item={'name':'万科','city_name':'深圳','location':'龙岗','price':str(my_dict)} 把字典转化为str后,变成了item。 插入语句: sql2='''insert into house ( name,city_name,location,price) values ('%s','%s','%s','%s')''' %(item['name'], item['city_name'], item['location'], item['price'].encode('utf-8')) 因为price中含有很多引号,{"2017-07" : [ {"origin" : "LJ","price" : 44267,"crawl_date" : "2017-09-01"}]} 所以运行的时候回提示说sql错误,要怎样才能把这些字符显示为字符而非转义字符 ?
为什么全局变量在进程中赋值后,线程中接收不到。
我定义了一个global全局变量,在一个进程中对这个变量赋值,然后想在一个线程中使用这个变量,但是发现线程中并没有收到应该被赋值的全局变量,不知道是什么原因,以下是我的部分代码 ``` global a a = {} class Control_system(QMainWindow, Ui_Control_system): socketQueue = multiprocessing.Queue() def __init__(self, parent=None): super().__init__(parent) self.setupUi(self) self.p1 = multiprocessing.Process(target=Control_system.connect, args=(self.socketQueue,)) self.p1.start() self.timer = QTimer(self) self.timer.timeout.connect(self.client) self.timer.start(1000) @staticmethod def connect(queue): ip_port = ("192.168.1.251", 8880) s = socketserver.ThreadingTCPServer(ip_port, MyServer) s.serve_forever() def client(self): print(a) # 这里的a还是空字典 class MyServer(socketserver.BaseRequestHandler): def handle(self): print("conn is :", self.request) # conn print("addr is :", self.client_address) # addr a[self.client_address] = self.request print("a:",a) # 这里的a是有值的 if __name__ == '__main__': if not QApplication.instance(): app = QApplication(sys.argv) else: app = QApplication.instance() w = Control_system() w.show() sys.exit(app.exec()) ```
C++调用python脚本(test.py这个脚本中import numpy)程序崩溃
我想在c++中调用python的一个脚本,这个脚本中我只是import了一个numpy就报错了,而如果是简单的脚本(没有import第三方库)就不会出错,我已经把: INCLUDEPATH += C:/Python27/include/ LIBS += C:/Python27/libs/python27.lib 添加进去了, ``` pyrun_simplestring("import sys"); pyrun_simplestring("import numpy"); pyrun_simplestring("sys.path.append('c:\python27\lib\site-packages\')"); pyerr_print(); pyobject * pmodule = null; pyobject * pfunc = null; pmodule =pyimport_importmodule("test_my"); //test001:python文件名 pfunc= pyobject_getattrstring(pmodule, "testdict"); //add:python文件中的函数名 pyobject *pargs = pytuple_new(1); pyobject *pdict = pydict_new(); //创建字典类型变量 pydict_setitemstring(pdict, "name", py_buildvalue("s", "wangyao")); //往字典类型变量中填充数据 pydict_setitemstring(pdict, "age", py_buildvalue("i", 25)); //往字典类型变量中填充数据 pytuple_setitem(pargs, 0, pdict); //0---序号 将字典类型变量添加到参数元组中 pyobject *preturn = null; preturn = pyeval_callobject(pfunc, pargs); //调用函数 int size = pydict_size(preturn); cout << "返回字典的大小为: " << size << endl; pyobject *pnewage = pydict_getitemstring(preturn, "age"); int newage; pyarg_parse(pnewage, "i", &newage); cout << "true age: " << newage << endl; py_finalize(); ``` 这是python的脚本: ``` #import numpy as np def HelloWorld(): print "Hello World" def add(a, b): #tmp=np.random.randint(10,88) return a+b def TestDict(dict): print dict dict["Age"] = 17 return dict class Person: def greet(self, greetStr): print greetStr #print add(5,7) #a = raw_input("Enter To Continue...") ``` 老是报错,但如果我把import numpy去掉就没问题,求大神解答,困扰好久了~~~
python子类继承父类时找不到__init__
![UML](https://img-ask.csdn.net/upload/201901/06/1546726233_455371.png) 类PatientManagement有3个方法 1. add_patient 添加一个新病人(属于类 Patient)。病人对应一个ID,第一个ID为0,第二个为1,以此类推 2. get_patient 返回病人的ID 3.get_statistics 返回女病人和男病人的数量,以及平均BMI 病人和ID以字典的方式储存在变量 patients 中,以{id: Patient}形式,最后一个使用的ID储存在变量 last_used 中。 类 Patient 有两个方法 1. to_string 打印 名 姓 出生日 月 年 性别 身高(cm)体重(kg) 2. get_bmi 计算这个病人的BMI,BMI = weight in KG / (height in M)² get_statistics调用get_bmi 且不再进行一变BMI计算 运行结果是这样 ``` class Patient(object): def __init__(self, first_name, last_name, birth_year, birth_month, birth_day, sex, body_height, body_weight): self.first_name = first_name self.last_name = last_name self.year = birth_year self.month = birth_month self.day = birth_day self.sex = sex self.height = body_height self.weight = body_weight def to_string(self): print("Name: {} {}, geboren am {}.{}.{}, geschlecht:{}, {}cm, {}kg".format(self.first_name, self.last_name, self.day, self.month, self.year, self.sex, self.height, self.weight)) def get_bmi(self): return self.weight / ((self.height/100) ** 2) a =[] b = [] patients = {} i = 0 class PatientManagement(Patient): fpCount = 0 mpCount = 0 def add_patient(self): patients[i] = Patient() i += 1 if self.sex == f: PatientManagement.fpCount += 1 if self.sex == m: PatientManagement.mpCount += 1 for k in patients.keys(): a.append(k) b.append(patients[k]) last_used = a[-1] def get_patient(self): new_dict = {v:k for k,v in patients.items()} patient_id = new_dict[Patient()] return patients_id def get_statistics(self): bmi = [] for patient in b: bmi.append(patient.get_bmi()) avg_bmi = np.mean(bmi) return "female Patients:{}, male Patients:{}, average BMI: {}".format(PatientManagement.fpCount, PatientManagement.mpCount, avg_bmi) ``` ``` p1 = Patient("Meier", "Lena", 1988, 12, 12, 'f', 164, 50) pm1 = PatientManagement(p1) pm1.add_patient() pm1.get_patient() ``` 这里报错是问题在哪里? ``` --------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-141-ebb011904724> in <module>() 1 p1 = Patient("Meier", "Lena", 1988, 12, 12, 'f', 164, 50) ----> 2 pm1 = PatientManagement(p1) 3 pm1.add_patient() 4 pm1.get_patient() TypeError: __init__() missing 7 required positional arguments: 'last_name', 'birth_year', 'birth_month', 'birth_day', 'sex', 'body_height', and 'body_weight' ```
python爬新浪新闻内容,为什么运行完stock里面为空……
#! /usr/bin/env python #coding=utf-8 from scrapy.selector import Selector from scrapy.http import Request import re,os from bs4 import BeautifulSoup from scrapy.spider import Spider import urllib2,thread #处理编码问题 import sys reload(sys) sys.setdefaultencoding('gb18030') #flag的作用是保证第一次爬取的时候不进行单个新闻页面内容的爬取 flag=1 projectpath='C:\\Users\DELL\\Desktop\\pythonproject\\mypro\\' def loop(*response): sel = Selector(response[0]) #get title title = sel.xpath('//h1/text()').extract() #get pages pages=sel.xpath('//div[@id="artibody"]//p/text()').extract() #get chanel_id & comment_id s=sel.xpath('//meta[@name="comment"]').extract() #comment_id = channel[index+3:index+15] index2=len(response[0].url) news_id=response[0].url[index2-14:index2-6] comment_id='31-1-'+news_id #评论内容都在这个list中 cmntlist=[] page=1 #含有新闻url,标题,内容,评论的文件 file2=None #该变量的作用是当某新闻下存在非手机用户评论时置为False is_all_tel=True while((page==1) or (cmntlist != [])): tel_count=0 #each page tel_user_count #提取到的评论url url="http://comment5.news.sina.com.cn/page/info?version=1&format=js&channel=cj&newsid="+str(comment_id)+"&group=0&compress=1&ie=gbk&oe=gbk&page="+str(page)+"&page_size=100" url_contain=urllib2.urlopen(url).read() b='={' after = url_contain[url_contain.index(b)+len(b)-1:] #字符串中的None对应python中的null,不然执行eval时会出错 after=after.replace('null','None') #转换为字典变量text text=eval(after) if 'cmntlist' in text['result']: cmntlist=text['result']['cmntlist'] else: cmntlist=[] if cmntlist != [] and (page==1): filename=str(comment_id)+'.txt' path=projectpath+'stock\\' +filename file2=open(path,'a+') news_content=str('') for p in pages: news_content=news_content+p+'\n' item="<url>"+response[0].url+"</url>"+'\n\n'+"<title>"+str(title[0])+"</title>\n\n"+"<content>\n"+str(news_content)+"</content>\n\n<comment>\n" file2.write(item) if cmntlist != []: content='' for status_dic in cmntlist: if status_dic['uid']!='0': is_all_tel=False #这一句视编码情况而定,在这里去掉decode和encode也行 s=status_dic['content'].decode('UTF-8').encode('GBK') #见另一篇博客“三张图” s=s.replace("'",'"') s=s.replace("\n",'') s1="u'"+s+"'" try: ss=eval(s1) except: try: s1='u"'+s+'"' ss=eval(s1) except: return content=content+status_dic['time']+'\t'+status_dic['uid']+'\t'+ss+'\n' #当属于手机用户时 else: tel_count=tel_count+1 #当一个page下不都是手机用户时,这里也可以用is_all_tel进行判断,一种是用开关的方式,一种是统计的方式 #算了不改了 if tel_count!=len(cmntlist): file2.write(content) page=page+1 #while loop end here if file2!=None: #当都是手机用户时,移除文件,否则写入"</comment>"到文件尾 if is_all_tel: file2.close() try: os.remove(file2.name) except WindowsError: pass else: file2.write("</comment>") file2.close() class DmozSpider(Spider): name = "stock" allowed_domains = ["sina.com.cn"] #在本程序中,start_urls并不重要,因为并没有解析 start_urls = [ "http://news.sina.com.cn/" ] global projectpath if os.path.exists(projectpath+'stock'): pass else: os.mkdir(projectpath+'stock') def parse(self, response): #这个scrapy.selector.Selector是个不错的处理字符串的类,python对编码很严格,它却处理得很好 #在做这个爬虫的时候,碰到很多奇奇怪怪的编码问题,主要是中文,试过很多既有的类,BeautifulSoup处理得也不是很好 sel = Selector(response) global flag if(flag==1): flag=2 page=1 while page<260: url="http://roll.finance.sina.com.cn/finance/zq1/index_" url=url+str(page)+".shtml" #伪装为浏览器 user_agent = 'Mozilla/4.0 (compatible; MSIE 5.5; Windows NT)' headers = { 'User-Agent' : user_agent } req = urllib2.Request(url, headers=headers) response = urllib2.urlopen(req) url_contain = response.read() #利用BeautifulSoup进行文档解析 soup = BeautifulSoup(url_contain) params = soup.findAll('div',{'class':'listBlk'}) if os.path.exists(projectpath+'stock\\'+'link'): pass else: os.mkdir(projectpath+'stock\\'+'link') filename='link.txt' path=projectpath+'stock\\link\\' + filename filelink=open(path,'a+') for params_item in params: persons = params_item.findAll('li') for item in persons: href=item.find('a') mil_link= href.get('href') filelink.write(str(mil_link)+'\n') #递归调用parse,传入新的爬取url yield Request(mil_link, callback=self.parse) page=page+1 #对单个新闻页面新建线程进行爬取 if flag!=1: if (response.status != 404) and (response.status != 502): thread.start_new_thread(loop,(response,))
Python多线程问题,target以及kwargs传参出错,请问应该怎么写
``` def A(a,b,c): 代码块省略 def B(a,b,c,d): 代码块省略 def thread(self,arg*): t1 = threading.Thread(target=A,args=(a,b,c)) ``` 问题一:这里我想参数target=需要多开线程的方法名,然后我随便定义一个变量作为方法名参数传到target里面,,不行,,程序报错。求正确的传参方法,难不成我要为每一个要多开的方法都要多写一个多线程方法,,仅仅改个方法名参数?这太麻烦了 问题二:args这个参数我想改成kwargs字典形式的参数应该怎么改。t1 = threading.Thread(target=A,kwargs={a=1,b=2,c=3}) 这样报错,,然后改成t1 = threading.Thread(target=A,kwargs={‘a=1,b=2,c=3’})再改成一对参数和值加一组单引号还是报错,,求正确格式。。
基于tensorflow的pix2pix代码中如何做到输入图像和输出图像分辨率不一致
问题:例如在自己制作了成对的输入(input256×256 target 200×256)后,如何让输入图像和输出图像分辨率不一致,例如成对图像中:input的分辨率是256×256, output 和target都是200×256,需要修改哪里的参数。 论文参考:《Image-to-Image Translation with Conditional Adversarial Networks》 代码参考:https://blog.csdn.net/MOU_IT/article/details/80802407?utm_source=blogxgwz0 # coding=utf-8 from __future__ import absolute_import from __future__ import division from __future__ import print_function import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf import numpy as np import os import glob import random import collections import math import time # https://github.com/affinelayer/pix2pix-tensorflow train_input_dir = "D:/Project/pix2pix-tensorflow-master/facades/train/" # 训练集输入 train_output_dir = "D:/Project/pix2pix-tensorflow-master/facades/train_out/" # 训练集输出 test_input_dir = "D:/Project/pix2pix-tensorflow-master/facades/val/" # 测试集输入 test_output_dir = "D:/Project/pix2pix-tensorflow-master/facades/test_out/" # 测试集的输出 checkpoint = "D:/Project/pix2pix-tensorflow-master/facades/train_out/" # 保存结果的目录 seed = None max_steps = None # number of training steps (0 to disable) max_epochs = 200 # number of training epochs progress_freq = 50 # display progress every progress_freq steps trace_freq = 0 # trace execution every trace_freq steps display_freq = 50 # write current training images every display_freq steps save_freq = 500 # save model every save_freq steps, 0 to disable separable_conv = False # use separable convolutions in the generator aspect_ratio = 1 #aspect ratio of output images (width/height) batch_size = 1 # help="number of images in batch") which_direction = "BtoA" # choices=["AtoB", "BtoA"]) ngf = 64 # help="number of generator filters in first conv layer") ndf = 64 # help="number of discriminator filters in first conv layer") scale_size = 286 # help="scale images to this size before cropping to 256x256") flip = True # flip images horizontally no_flip = True # don't flip images horizontally lr = 0.0002 # initial learning rate for adam beta1 = 0.5 # momentum term of adam l1_weight = 100.0 # weight on L1 term for generator gradient gan_weight = 1.0 # weight on GAN term for generator gradient output_filetype = "png" # 输出图像的格式 EPS = 1e-12 # 极小数,防止梯度为损失为0 CROP_SIZE = 256 # 图片的裁剪大小 # 命名元组,用于存放加载的数据集合创建好的模型 Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch") Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train") # 图像预处理 [0, 1] => [-1, 1] def preprocess(image): with tf.name_scope("preprocess"): return image * 2 - 1 # 图像后处理[-1, 1] => [0, 1] def deprocess(image): with tf.name_scope("deprocess"): return (image + 1) / 2 # 判别器的卷积定义,batch_input为 [ batch , 256 , 256 , 6 ] def discrim_conv(batch_input, out_channels, stride): # [ batch , 256 , 256 , 6 ] ===>[ batch , 258 , 258 , 6 ] padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT") ''' [0,0]: 第一维batch大小不扩充 [1,1]:第二维图像宽度左右各扩充一列,用0填充 [1,1]:第三维图像高度上下各扩充一列,用0填充 [0,0]:第四维图像通道不做扩充 ''' return tf.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid", kernel_initializer=tf.random_normal_initializer(0, 0.02)) # 生成器的卷积定义,卷积核为4*4,步长为2,输出图像为输入的一半 def gen_conv(batch_input, out_channels): # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels] initializer = tf.random_normal_initializer(0, 0.02) if separable_conv: return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer) else: return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer) # 生成器的反卷积定义 def gen_deconv(batch_input, out_channels): # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels] initializer = tf.random_normal_initializer(0, 0.02) if separable_conv: _b, h, w, _c = batch_input.shape resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer) else: return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer) # 定义LReLu激活函数 def lrelu(x, a): with tf.name_scope("lrelu"): # adding these together creates the leak part and linear part # then cancels them out by subtracting/adding an absolute value term # leak: a*x/2 - a*abs(x)/2 # linear: x/2 + abs(x)/2 # this block looks like it has 2 inputs on the graph unless we do this x = tf.identity(x) return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x) # 批量归一化图像 def batchnorm(inputs): return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02)) # 检查图像的维度 def check_image(image): assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels") with tf.control_dependencies([assertion]): image = tf.identity(image) if image.get_shape().ndims not in (3, 4): raise ValueError("image must be either 3 or 4 dimensions") # make the last dimension 3 so that you can unstack the colors shape = list(image.get_shape()) shape[-1] = 3 image.set_shape(shape) return image # 去除文件的后缀,获取文件名 def get_name(path): # os.path.basename(),返回path最后的文件名。若path以/或\结尾,那么就会返回空值。 # os.path.splitext(),分离文件名与扩展名;默认返回(fname,fextension)元组 name, _ = os.path.splitext(os.path.basename(path)) return name # 加载数据集,从文件读取-->解码-->归一化--->拆分为输入和目标-->像素转为[-1,1]-->转变形状 def load_examples(input_dir): if input_dir is None or not os.path.exists(input_dir): raise Exception("input_dir does not exist") # 匹配第一个参数的路径中所有的符合条件的文件,并将其以list的形式返回。 input_paths = glob.glob(os.path.join(input_dir, "*.jpg")) # 图像解码器 decode = tf.image.decode_jpeg if len(input_paths) == 0: input_paths = glob.glob(os.path.join(input_dir, "*.png")) decode = tf.image.decode_png if len(input_paths) == 0: raise Exception("input_dir contains no image files") # 如果文件名是数字,则用数字进行排序,否则用字母排序 if all(get_name(path).isdigit() for path in input_paths): input_paths = sorted(input_paths, key=lambda path: int(get_name(path))) else: input_paths = sorted(input_paths) sess = tf.Session() with tf.name_scope("load_images"): # 把我们需要的全部文件打包为一个tf内部的queue类型,之后tf开文件就从这个queue中取目录了, # 如果是训练模式时,shuffle为True path_queue = tf.train.string_input_producer(input_paths, shuffle=True) # Read的输出将是一个文件名(key)和该文件的内容(value,每次读取一个文件,分多次读取)。 reader = tf.WholeFileReader() paths, contents = reader.read(path_queue) # 对文件进行解码并且对图片作归一化处理 raw_input = decode(contents) raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32) # 归一化处理 # 判断两个值知否相等,如果不等抛出异常 assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels") ''' 对于control_dependencies这个管理器,只有当里面的操作是一个op时,才会生效,也就是先执行传入的 参数op,再执行里面的op。如果里面的操作不是定义的op,图中就不会形成一个节点,这样该管理器就失效了。 tf.identity是返回一个一模一样新的tensor的op,这会增加一个新节点到gragh中,这时control_dependencies就会生效. ''' with tf.control_dependencies([assertion]): raw_input = tf.identity(raw_input) raw_input.set_shape([None, None, 3]) # 图像值由[0,1]--->[-1, 1] width = tf.shape(raw_input)[1] # [height, width, channels] a_images = preprocess(raw_input[:, :width // 2, :]) # 256*256*3 b_images = preprocess(raw_input[:, width // 2:, :]) # 256*256*3 # 这里的which_direction为:BtoA if which_direction == "AtoB": inputs, targets = [a_images, b_images] elif which_direction == "BtoA": inputs, targets = [b_images, a_images] else: raise Exception("invalid direction") # synchronize seed for image operations so that we do the same operations to both # input and output images seed = random.randint(0, 2 ** 31 - 1) # 图像预处理,翻转、改变形状 with tf.name_scope("input_images"): input_images = transform(inputs) with tf.name_scope("target_images"): target_images = transform(targets) # 获得输入图像、目标图像的batch块 paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=batch_size) steps_per_epoch = int(math.ceil(len(input_paths) / batch_size)) return Examples( paths=paths_batch, # 输入的文件名块 inputs=inputs_batch, # 输入的图像块 targets=targets_batch, # 目标图像块 count=len(input_paths), # 数据集的大小 steps_per_epoch=steps_per_epoch, # batch的个数 ) # 图像预处理,翻转、改变形状 def transform(image): r = image if flip: r = tf.image.random_flip_left_right(r, seed=seed) # area produces a nice downscaling, but does nearest neighbor for upscaling # assume we're going to be doing downscaling here r = tf.image.resize_images(r, [scale_size, scale_size], method=tf.image.ResizeMethod.AREA) offset = tf.cast(tf.floor(tf.random_uniform([2], 0, scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32) if scale_size > CROP_SIZE: r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE) elif scale_size < CROP_SIZE: raise Exception("scale size cannot be less than crop size") return r # 创建生成器,这是一个编码解码器的变种,输入输出均为:256*256*3, 像素值为[-1,1] def create_generator(generator_inputs, generator_outputs_channels): layers = [] # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf] with tf.variable_scope("encoder_1"): output = gen_conv(generator_inputs, ngf) # ngf为第一个卷积层的卷积核核数量,默认为 64 layers.append(output) layer_specs = [ ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2] ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4] ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8] ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8] ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8] ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8] ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8] ] # 卷积的编码器 for out_channels in layer_specs: with tf.variable_scope("encoder_%d" % (len(layers) + 1)): # 对最后一层使用激活函数 rectified = lrelu(layers[-1], 0.2) # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels] convolved = gen_conv(rectified, out_channels) output = batchnorm(convolved) layers.append(output) layer_specs = [ (ngf * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2] (ngf * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2] (ngf * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2] (ngf * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2] (ngf * 4, 0.0), # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2] (ngf * 2, 0.0), # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2] (ngf, 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2] ] # 卷积的解码器 num_encoder_layers = len(layers) # 8 for decoder_layer, (out_channels, dropout) in enumerate(layer_specs): skip_layer = num_encoder_layers - decoder_layer - 1 with tf.variable_scope("decoder_%d" % (skip_layer + 1)): if decoder_layer == 0: # first decoder layer doesn't have skip connections # since it is directly connected to the skip_layer input = layers[-1] else: input = tf.concat([layers[-1], layers[skip_layer]], axis=3) rectified = tf.nn.relu(input) # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels] output = gen_deconv(rectified, out_channels) output = batchnorm(output) if dropout > 0.0: output = tf.nn.dropout(output, keep_prob=1 - dropout) layers.append(output) # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels] with tf.variable_scope("decoder_1"): input = tf.concat([layers[-1], layers[0]], axis=3) rectified = tf.nn.relu(input) output = gen_deconv(rectified, generator_outputs_channels) output = tf.tanh(output) layers.append(output) return layers[-1] # 创建判别器,输入生成的图像和真实的图像:两个[batch,256,256,3],元素值值[-1,1],输出:[batch,30,30,1],元素值为概率 def create_discriminator(discrim_inputs, discrim_targets): n_layers = 3 layers = [] # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2] input = tf.concat([discrim_inputs, discrim_targets], axis=3) # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf] with tf.variable_scope("layer_1"): convolved = discrim_conv(input, ndf, stride=2) rectified = lrelu(convolved, 0.2) layers.append(rectified) # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2] # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4] # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8] for i in range(n_layers): with tf.variable_scope("layer_%d" % (len(layers) + 1)): out_channels = ndf * min(2 ** (i + 1), 8) stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1 convolved = discrim_conv(layers[-1], out_channels, stride=stride) normalized = batchnorm(convolved) rectified = lrelu(normalized, 0.2) layers.append(rectified) # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1] with tf.variable_scope("layer_%d" % (len(layers) + 1)): convolved = discrim_conv(rectified, out_channels=1, stride=1) output = tf.sigmoid(convolved) layers.append(output) return layers[-1] # 创建Pix2Pix模型,inputs和targets形状为:[batch_size, height, width, channels] def create_model(inputs, targets): with tf.variable_scope("generator"): out_channels = int(targets.get_shape()[-1]) outputs = create_generator(inputs, out_channels) # create two copies of discriminator, one for real pairs and one for fake pairs # they share the same underlying variables with tf.name_scope("real_discriminator"): with tf.variable_scope("discriminator"): # 2x [batch, height, width, channels] => [batch, 30, 30, 1] predict_real = create_discriminator(inputs, targets) # 条件变量图像和真实图像 with tf.name_scope("fake_discriminator"): with tf.variable_scope("discriminator", reuse=True): # 2x [batch, height, width, channels] => [batch, 30, 30, 1] predict_fake = create_discriminator(inputs, outputs) # 条件变量图像和生成的图像 # 判别器的损失,判别器希望V(G,D)尽可能大 with tf.name_scope("discriminator_loss"): # minimizing -tf.log will try to get inputs to 1 # predict_real => 1 # predict_fake => 0 discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS))) # 生成器的损失,生成器希望V(G,D)尽可能小 with tf.name_scope("generator_loss"): # predict_fake => 1 # abs(targets - outputs) => 0 gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS)) gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs)) gen_loss = gen_loss_GAN * gan_weight + gen_loss_L1 * l1_weight # 判别器训练 with tf.name_scope("discriminator_train"): # 判别器需要优化的参数 discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] # 优化器定义 discrim_optim = tf.train.AdamOptimizer(lr, beta1) # 计算损失函数对优化参数的梯度 discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars) # 更新该梯度所对应的参数的状态,返回一个op discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars) # 生成器训练 with tf.name_scope("generator_train"): with tf.control_dependencies([discrim_train]): # 生成器需要优化的参数列表 gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] # 定义优化器 gen_optim = tf.train.AdamOptimizer(lr, beta1) # 计算需要优化的参数的梯度 gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars) # 更新该梯度所对应的参数的状态,返回一个op gen_train = gen_optim.apply_gradients(gen_grads_and_vars) ''' 在采用随机梯度下降算法训练神经网络时,使用 tf.train.ExponentialMovingAverage 滑动平均操作的意义在于 提高模型在测试数据上的健壮性(robustness)。tensorflow 下的 tf.train.ExponentialMovingAverage 需要 提供一个衰减率(decay)。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度, ExponentialMovingAverage 对每一个(待更新训练学习的)变量(variable)都会维护一个影子变量 (shadow variable)。影子变量的初始值就是这个变量的初始值, shadow_variable=decay×shadow_variable+(1−decay)×variable ''' ema = tf.train.ExponentialMovingAverage(decay=0.99) update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1]) # global_step = tf.train.get_or_create_global_step() incr_global_step = tf.assign(global_step, global_step + 1) return Model( predict_real=predict_real, # 条件变量(输入图像)和真实图像之间的概率值,形状为;[batch,30,30,1] predict_fake=predict_fake, # 条件变量(输入图像)和生成图像之间的概率值,形状为;[batch,30,30,1] discrim_loss=ema.average(discrim_loss), # 判别器损失 discrim_grads_and_vars=discrim_grads_and_vars, # 判别器需要优化的参数和对应的梯度 gen_loss_GAN=ema.average(gen_loss_GAN), # 生成器的损失 gen_loss_L1=ema.average(gen_loss_L1), # 生成器的 L1损失 gen_grads_and_vars=gen_grads_and_vars, # 生成器需要优化的参数和对应的梯度 outputs=outputs, # 生成器生成的图片 train=tf.group(update_losses, incr_global_step, gen_train), # 打包需要run的操作op ) # 保存图像 def save_images(output_dir, fetches, step=None): image_dir = os.path.join(output_dir, "images") if not os.path.exists(image_dir): os.makedirs(image_dir) filesets = [] for i, in_path in enumerate(fetches["paths"]): name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8"))) fileset = {"name": name, "step": step} for kind in ["inputs", "outputs", "targets"]: filename = name + "-" + kind + ".png" if step is not None: filename = "%08d-%s" % (step, filename) fileset[kind] = filename out_path = os.path.join(image_dir, filename) contents = fetches[kind][i] with open(out_path, "wb") as f: f.write(contents) filesets.append(fileset) return filesets # 将结果写入HTML网页 def append_index(output_dir, filesets, step=False): index_path = os.path.join(output_dir, "index.html") if os.path.exists(index_path): index = open(index_path, "a") else: index = open(index_path, "w") index.write("<html><body><table><tr>") if step: index.write("<th>step</th>") index.write("<th>name</th><th>input</th><th>output</th><th>target</th></tr>") for fileset in filesets: index.write("<tr>") if step: index.write("<td>%d</td>" % fileset["step"]) index.write("<td>%s</td>" % fileset["name"]) for kind in ["inputs", "outputs", "targets"]: index.write("<td><img src='images/%s'></td>" % fileset[kind]) index.write("</tr>") return index_path # 转变图像的尺寸、并且将[0,1]--->[0,255] def convert(image): if aspect_ratio != 1.0: # upscale to correct aspect ratio size = [CROP_SIZE, int(round(CROP_SIZE * aspect_ratio))] image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC) # 将数据的类型转换为8位无符号整型 return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True) # 主函数 def train(): # 设置随机数种子的值 global seed if seed is None: seed = random.randint(0, 2 ** 31 - 1) tf.set_random_seed(seed) np.random.seed(seed) random.seed(seed) # 创建目录 if not os.path.exists(train_output_dir): os.makedirs(train_output_dir) # 加载数据集,得到输入数据和目标数据并把范围变为 :[-1,1] examples = load_examples(train_input_dir) print("load successful ! examples count = %d" % examples.count) # 创建模型,inputs和targets是:[batch_size, height, width, channels] # 返回值: model = create_model(examples.inputs, examples.targets) print("create model successful!") # 图像处理[-1, 1] => [0, 1] inputs = deprocess(examples.inputs) targets = deprocess(examples.targets) outputs = deprocess(model.outputs) # 把[0,1]的像素点转为RGB值:[0,255] with tf.name_scope("convert_inputs"): converted_inputs = convert(inputs) with tf.name_scope("convert_targets"): converted_targets = convert(targets) with tf.name_scope("convert_outputs"): converted_outputs = convert(outputs) # 对图像进行编码以便于保存 with tf.name_scope("encode_images"): display_fetches = { "paths": examples.paths, # tf.map_fn接受一个函数对象和集合,用函数对集合中每个元素分别处理 "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"), "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"), "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"), } with tf.name_scope("parameter_count"): parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()]) # 只保存最新一个checkpoint saver = tf.train.Saver(max_to_keep=20) init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) print("parameter_count =", sess.run(parameter_count)) if max_epochs is not None: max_steps = examples.steps_per_epoch * max_epochs # 400X200=80000 # 因为是从文件中读取数据,所以需要启动start_queue_runners() # 这个函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) # 运行训练集 print("begin trainning......") print("max_steps:", max_steps) start = time.time() for step in range(max_steps): def should(freq): return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1) print("step:", step) # 定义一个需要run的所有操作的字典 fetches = { "train": model.train } # progress_freq为 50,每50次计算一次三个损失,显示进度 if should(progress_freq): fetches["discrim_loss"] = model.discrim_loss fetches["gen_loss_GAN"] = model.gen_loss_GAN fetches["gen_loss_L1"] = model.gen_loss_L1 # display_freq为 50,每50次保存一次输入、目标、输出的图像 if should(display_freq): fetches["display"] = display_fetches # 运行各种操作, results = sess.run(fetches) # display_freq为 50,每50次保存输入、目标、输出的图像 if should(display_freq): print("saving display images") filesets = save_images(train_output_dir, results["display"], step=step) append_index(train_output_dir, filesets, step=True) # progress_freq为 50,每50次打印一次三种损失的大小,显示进度 if should(progress_freq): # global_step will have the correct step count if we resume from a checkpoint train_epoch = math.ceil(step / examples.steps_per_epoch) train_step = (step - 1) % examples.steps_per_epoch + 1 rate = (step + 1) * batch_size / (time.time() - start) remaining = (max_steps - step) * batch_size / rate print("progress epoch %d step %d image/sec %0.1f remaining %dm" % ( train_epoch, train_step, rate, remaining / 60)) print("discrim_loss", results["discrim_loss"]) print("gen_loss_GAN", results["gen_loss_GAN"]) print("gen_loss_L1", results["gen_loss_L1"]) # save_freq为500,每500次保存一次模型 if should(save_freq): print("saving model") saver.save(sess, os.path.join(train_output_dir, "model"), global_step=step) # 测试 def test(): # 设置随机数种子的值 global seed if seed is None: seed = random.randint(0, 2 ** 31 - 1) tf.set_random_seed(seed) np.random.seed(seed) random.seed(seed) # 创建目录 if not os.path.exists(test_output_dir): os.makedirs(test_output_dir) if checkpoint is None: raise Exception("checkpoint required for test mode") # disable these features in test mode scale_size = CROP_SIZE flip = False # 加载数据集,得到输入数据和目标数据 examples = load_examples(test_input_dir) print("load successful ! examples count = %d" % examples.count) # 创建模型,inputs和targets是:[batch_size, height, width, channels] model = create_model(examples.inputs, examples.targets) print("create model successful!") # 图像处理[-1, 1] => [0, 1] inputs = deprocess(examples.inputs) targets = deprocess(examples.targets) outputs = deprocess(model.outputs) # 把[0,1]的像素点转为RGB值:[0,255] with tf.name_scope("convert_inputs"): converted_inputs = convert(inputs) with tf.name_scope("convert_targets"): converted_targets = convert(targets) with tf.name_scope("convert_outputs"): converted_outputs = convert(outputs) # 对图像进行编码以便于保存 with tf.name_scope("encode_images"): display_fetches = { "paths": examples.paths, # tf.map_fn接受一个函数对象和集合,用函数对集合中每个元素分别处理 "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"), "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"), "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"), } sess = tf.InteractiveSession() saver = tf.train.Saver(max_to_keep=1) ckpt = tf.train.get_checkpoint_state(checkpoint) saver.restore(sess,ckpt.model_checkpoint_path) start = time.time() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for step in range(examples.count): results = sess.run(display_fetches) filesets = save_images(test_output_dir, results) for i, f in enumerate(filesets): print("evaluated image", f["name"]) index_path = append_index(test_output_dir, filesets) print("wrote index at", index_path) print("rate", (time.time() - start) / max_steps) if __name__ == '__main__': train() #test()
tensorflow实现BP算法遇到了问题,求大神指点!!!
import tensorflow as tf import numpy as np #from tensorflow.examples.tutorials.mnist import input_data #载入数据集 #mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次的大小 #batch_size = 100 #????????????????????????????????? #计算一共有多少个批次 #n_batch = mnist.train.num_examples // batch_size #定义placeholder x_data=np.mat([[0.4984,0.5102,0.5213,0.5340], [0.5102,0.5213,0.5340,0.5407], [0.5213,0.5340,0.5407,0.5428], [0.5340,0.5407,0.5428,0.5530], [0.5407,0.5428,0.5530,0.5632], [0.5428,0.5530,0.5632,0.5739], [0.5530,0.5632,0.5739,0.5821], [0.5632,0.5739,0.5821,0.5920], [0.5739,0.5821,0.5920,0.5987], [0.5821,0.5920,0.5987,0.6043], [0.5920,0.5987,0.6043,0.6095], [0.5987,0.6043,0.6095,0.6161], [0.6043,0.6095,0.6161,0.6251], [0.6095,0.6161,0.6251,0.6318], [0.6161,0.6251,0.6318,0.6387], [0.6251,0.6318,0.6387,0.6462], [0.6318,0.6387,0.6462,0.6518], [0.6387,0.6462,0.6518,0.6589], [0.6462,0.6518,0.6589,0.6674], [0.6518,0.6589,0.6674,0.6786], [0.6589,0.6674,0.6786,0.6892], [0.6674,0.6786,0.6892,0.6988]]) y_data=np.mat([[0.5407], [0.5428], [0.5530], [0.5632], [0.5739], [0.5821], [0.5920], [0.5987], [0.6043], [0.6095], [0.6161], [0.6251], [0.6318], [0.6387], [0.6462], [0.6518], [0.6589], [0.6674], [0.6786], [0.6892], [0.6988], [0.7072]]) xs = tf.placeholder(tf.float32,[None,4]) # 样本数未知,特征数为1,占位符最后要以字典形式在运行中填入 ys = tf.placeholder(tf.float32,[None,1]) #创建一个简单的神经网络 W1 = tf.Variable(tf.truncated_normal([4,10],stddev=0.1)) b1 = tf.Variable(tf.zeros([10])+0.1) L1 = tf.nn.tanh(tf.matmul(x,W1)+b1) W2 = tf.Variable(tf.truncated_normal([10,1],stddev=0.1)) b2 = tf.Variable(tf.zeros([1])+0.1) L2 = tf.nn.softmax(tf.matmul(L1,W2)+b2) #二次代价函数 #loss = tf.reduce_mean(tf.square(y-prediction)) #loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=ys,logits=prediction)) loss = tf.reduce_mean(tf.reduce_sum(tf.square((ys-L2)),reduction_indices = [1]))#需要向相加索引号,redeuc执行跨纬度操作 #使用梯度下降法 #train_step = tf.train.GradientDescentOptimizer(0.1).mnimize(loss) train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) #train = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 选择梯度下降法 #初始化变量 #init = tf.global_variables_initializer() init = tf.initialize_all_variables() #结果存放在一个布尔型列表中 #correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1)) #求准确率 #accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess: sess.run(init) for epoch in range(21): for i in range(22): #batch_xs,batch_ys = mnist.train.next_batch(batch_size) #????????????????????????? sess.run(train_step,feed_dict={xs:x_data,ys:y_data}) #test_acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0}) #train_acc = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels,keep_prob:1.0}) print (sess.run(prediction,feed_dict={xs:x_data,ys:y_data})) 提示:WARNING:tensorflow:From <ipython-input-10-578836c021a3>:89 in <module>.: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02. Instructions for updating: Use `tf.global_variables_initializer` instead. --------------------------------------------------------------------------- InvalidArgumentError Traceback (most recent call last) C:\Users\Administrator\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args) 1020 try: -> 1021 return fn(*args) 1022 except errors.OpError as e: C:\Users\Administrator\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata) 1002 feed_dict, fetch_list, target_list, -> 1003 status, run_metadata) 1004 。。。
萌新求助代码问题,求各位老司机指点迷津,感激不尽!!
题目要求如下: 尝试编写一个用户登陆程序(这次尝试将该功能封装成函数),程序实现如图: ![图片说明](https://img-ask.csdn.net/upload/201701/08/1483886164_884313.png) 个人代码如下: ``` data = {} def denglu(): print('''\n |--- 新建用户:N/n ---| |--- 登陆账号:E/e ---| |--- 退出程序:Q/q ---|''') order = input('|--- 请输入指令代码:') if order == 'n' or order == 'N': name = input('请输入用户名:') flag1 = 1 #flag1用以判断输入的用户名是否存在,存在为1,不存在为0 while flag1: if name in data: name = input('此用户名已经被使用,请重新输入:') else: flag1 = 0 password = input('请输入密码:') data = [name, password] print('注册成功,赶紧试试登陆吧^_^') return denglu() elif order == 'e' or order == 'E': name = input('请输入用户名:') flag2 = 1 #flag2用以判断输入的用户名是否存在,不存在为1,存在为0 while flag2: if name not in data: name = input('您输入的用户名不存在,请重新输入:') else: flag2 = 0 password = input('请输入密码:') flag3 = 1 while flag3: #flag3用以判断输入的密码是否存在,不存在为1,存在为0 if password != data[name]: password = input('密码错误,请重新输入:') else: flag3 = 0 print('欢迎进入xxoo系统,请点右上角的x结束程序') else: print('欢迎使用XXX登陆程序!') denglu() ``` 运行程序,当输入指令为n(N)或者e(E),输入用户名之后,报错,提示为: UnboundLocalError: local variable 'data' referenced before assignment 但是data字典是我在函数外面定义的全局变量,所以不太理解哪里有问题,希望有经验的小伙伴指出问题所在,感激不尽!!!!!!
Java学习的正确打开方式
在博主认为,对于入门级学习java的最佳学习方法莫过于视频+博客+书籍+总结,前三者博主将淋漓尽致地挥毫于这篇博客文章中,至于总结在于个人,实际上越到后面你会发现学习的最好方式就是阅读参考官方文档其次就是国内的书籍,博客次之,这又是一个层次了,这里暂时不提后面再谈。博主将为各位入门java保驾护航,各位只管冲鸭!!!上天是公平的,只要不辜负时间,时间自然不会辜负你。 何谓学习?博主所理解的学习,它
程序员必须掌握的核心算法有哪些?
由于我之前一直强调数据结构以及算法学习的重要性,所以就有一些读者经常问我,数据结构与算法应该要学习到哪个程度呢?,说实话,这个问题我不知道要怎么回答你,主要取决于你想学习到哪些程度,不过针对这个问题,我稍微总结一下我学过的算法知识点,以及我觉得值得学习的算法。这些算法与数据结构的学习大多数是零散的,并没有一本把他们全部覆盖的书籍。下面是我觉得值得学习的一些算法以及数据结构,当然,我也会整理一些看过
大学四年自学走来,这些私藏的实用工具/学习网站我贡献出来了
大学四年,看课本是不可能一直看课本的了,对于学习,特别是自学,善于搜索网上的一些资源来辅助,还是非常有必要的,下面我就把这几年私藏的各种资源,网站贡献出来给你们。主要有:电子书搜索、实用工具、在线视频学习网站、非视频学习网站、软件下载、面试/求职必备网站。 注意:文中提到的所有资源,文末我都给你整理好了,你们只管拿去,如果觉得不错,转发、分享就是最大的支持了。 一、PDF搜索网站推荐 对于大部
linux系列之常用运维命令整理笔录
本博客记录工作中需要的linux运维命令,大学时候开始接触linux,会一些基本操作,可是都没有整理起来,加上是做开发,不做运维,有些命令忘记了,所以现在整理成博客,当然vi,文件操作等就不介绍了,慢慢积累一些其它拓展的命令,博客不定时更新 顺便拉下票,我在参加csdn博客之星竞选,欢迎投票支持,每个QQ或者微信每天都可以投5票,扫二维码即可,http://m234140.nofollow.ax.
Vue + Spring Boot 项目实战(十四):用户认证方案与完善的访问拦截
本篇文章主要讲解 token、session 等用户认证方案的区别并分析常见误区,以及如何通过前后端的配合实现完善的访问拦截,为下一步权限控制的实现打下基础。
比特币原理详解
一、什么是比特币 比特币是一种电子货币,是一种基于密码学的货币,在2008年11月1日由中本聪发表比特币白皮书,文中提出了一种去中心化的电子记账系统,我们平时的电子现金是银行来记账,因为银行的背后是国家信用。去中心化电子记账系统是参与者共同记账。比特币可以防止主权危机、信用风险。其好处不多做赘述,这一层面介绍的文章很多,本文主要从更深层的技术原理角度进行介绍。 二、问题引入  假设现有4个人
程序员接私活怎样防止做完了不给钱?
首先跟大家说明一点,我们做 IT 类的外包开发,是非标品开发,所以很有可能在开发过程中会有这样那样的需求修改,而这种需求修改很容易造成扯皮,进而影响到费用支付,甚至出现做完了项目收不到钱的情况。 那么,怎么保证自己的薪酬安全呢? 我们在开工前,一定要做好一些证据方面的准备(也就是“讨薪”的理论依据),这其中最重要的就是需求文档和验收标准。一定要让需求方提供这两个文档资料作为开发的基础。之后开发
网页实现一个简单的音乐播放器(大佬别看。(⊙﹏⊙))
今天闲着无事,就想写点东西。然后听了下歌,就打算写个播放器。 于是乎用h5 audio的加上js简单的播放器完工了。 欢迎 改进 留言。 演示地点跳到演示地点 html代码如下`&lt;!DOCTYPE html&gt; &lt;html&gt; &lt;head&gt; &lt;title&gt;music&lt;/title&gt; &lt;meta charset="utf-8"&gt
Python十大装B语法
Python 是一种代表简单思想的语言,其语法相对简单,很容易上手。不过,如果就此小视 Python 语法的精妙和深邃,那就大错特错了。本文精心筛选了最能展现 Python 语法之精妙的十个知识点,并附上详细的实例代码。如能在实战中融会贯通、灵活使用,必将使代码更为精炼、高效,同时也会极大提升代码B格,使之看上去更老练,读起来更优雅。 1. for - else 什么?不是 if 和 else 才
数据库优化 - SQL优化
前面一篇文章从实例的角度进行数据库优化,通过配置一些参数让数据库性能达到最优。但是一些“不好”的SQL也会导致数据库查询变慢,影响业务流程。本文从SQL角度进行数据库优化,提升SQL运行效率。 判断问题SQL 判断SQL是否有问题时可以通过两个表象进行判断: 系统级别表象 CPU消耗严重 IO等待严重 页面响应时间过长
2019年11月中国大陆编程语言排行榜
2019年11月2日,我统计了某招聘网站,获得有效程序员招聘数据9万条。针对招聘信息,提取编程语言关键字,并统计如下: 编程语言比例 rank pl_ percentage 1 java 33.62% 2 c/c++ 16.42% 3 c_sharp 12.82% 4 javascript 12.31% 5 python 7.93% 6 go 7.25% 7
通俗易懂地给女朋友讲:线程池的内部原理
餐厅的约会 餐盘在灯光的照耀下格外晶莹洁白,女朋友拿起红酒杯轻轻地抿了一小口,对我说:“经常听你说线程池,到底线程池到底是个什么原理?”我楞了一下,心里想女朋友今天是怎么了,怎么突然问出这么专业的问题,但做为一个专业人士在女朋友面前也不能露怯啊,想了一下便说:“我先给你讲讲我前同事老王的故事吧!” 大龄程序员老王 老王是一个已经北漂十多年的程序员,岁数大了,加班加不动了,升迁也无望,于是拿着手里
经典算法(5)杨辉三角
写在前面: 我是 扬帆向海,这个昵称来源于我的名字以及女朋友的名字。我热爱技术、热爱开源、热爱编程。技术是开源的、知识是共享的。 这博客是对自己学习的一点点总结及记录,如果您对 Java、算法 感兴趣,可以关注我的动态,我们一起学习。 用知识改变命运,让我们的家人过上更好的生活。 目录一、杨辉三角的介绍二、杨辉三角的算法思想三、代码实现1.第一种写法2.第二种写法 一、杨辉三角的介绍 百度
腾讯算法面试题:64匹马8个跑道需要多少轮才能选出最快的四匹?
昨天,有网友私信我,说去阿里面试,彻底的被打击到了。问了为什么网上大量使用ThreadLocal的源码都会加上private static?他被难住了,因为他从来都没有考虑过这个问题。无独有偶,今天笔者又发现有网友吐槽了一道腾讯的面试题,我们一起来看看。 腾讯算法面试题:64匹马8个跑道需要多少轮才能选出最快的四匹? 在互联网职场论坛,一名程序员发帖求助到。二面腾讯,其中一个算法题:64匹
面试官:你连RESTful都不知道我怎么敢要你?
面试官:了解RESTful吗? 我:听说过。 面试官:那什么是RESTful? 我:就是用起来很规范,挺好的 面试官:是RESTful挺好的,还是自我感觉挺好的 我:都挺好的。 面试官:… 把门关上。 我:… 要干嘛?先关上再说。 面试官:我说出去把门关上。 我:what ?,夺门而去 文章目录01 前言02 RESTful的来源03 RESTful6大原则1. C-S架构2. 无状态3.统一的接
为啥国人偏爱Mybatis,而老外喜欢Hibernate/JPA呢?
关于SQL和ORM的争论,永远都不会终止,我也一直在思考这个问题。昨天又跟群里的小伙伴进行了一番讨论,感触还是有一些,于是就有了今天这篇文。 声明:本文不会下关于Mybatis和JPA两个持久层框架哪个更好这样的结论。只是摆事实,讲道理,所以,请各位看官勿喷。 一、事件起因 关于Mybatis和JPA孰优孰劣的问题,争论已经很多年了。一直也没有结论,毕竟每个人的喜好和习惯是大不相同的。我也看
SQL-小白最佳入门sql查询一
一 说明 如果是初学者,建议去网上寻找安装Mysql的文章安装,以及使用navicat连接数据库,以后的示例基本是使用mysql数据库管理系统; 二 准备前提 需要建立一张学生表,列分别是id,名称,年龄,学生信息;本示例中文章篇幅原因SQL注释略; 建表语句: CREATE TABLE `student` ( `id` int(11) NOT NULL AUTO_INCREMENT, `
项目中的if else太多了,该怎么重构?
介绍 最近跟着公司的大佬开发了一款IM系统,类似QQ和微信哈,就是聊天软件。我们有一部分业务逻辑是这样的 if (msgType = "文本") { // dosomething } else if(msgType = "图片") { // doshomething } else if(msgType = "视频") { // doshomething } else { // dosho
“狗屁不通文章生成器”登顶GitHub热榜,分分钟写出万字形式主义大作
一、垃圾文字生成器介绍 最近在浏览GitHub的时候,发现了这样一个骨骼清奇的雷人项目,而且热度还特别高。 项目中文名:狗屁不通文章生成器 项目英文名:BullshitGenerator 根据作者的介绍,他是偶尔需要一些中文文字用于GUI开发时测试文本渲染,因此开发了这个废话生成器。但由于生成的废话实在是太过富于哲理,所以最近已经被小伙伴们给玩坏了。 他的文风可能是这样的: 你发现,
程序员:我终于知道post和get的区别
IT界知名的程序员曾说:对于那些月薪三万以下,自称IT工程师的码农们,其实我们从来没有把他们归为我们IT工程师的队伍。他们虽然总是以IT工程师自居,但只是他们一厢情愿罢了。 此话一出,不知激起了多少(码农)程序员的愤怒,却又无可奈何,于是码农问程序员。 码农:你知道get和post请求到底有什么区别? 程序员:你看这篇就知道了。 码农:你月薪三万了? 程序员:嗯。 码农:你是怎么做到的? 程序员:
《程序人生》系列-这个程序员只用了20行代码就拿了冠军
你知道的越多,你不知道的越多 点赞再看,养成习惯GitHub上已经开源https://github.com/JavaFamily,有一线大厂面试点脑图,欢迎Star和完善 前言 这一期不算《吊打面试官》系列的,所有没前言我直接开始。 絮叨 本来应该是没有这期的,看过我上期的小伙伴应该是知道的嘛,双十一比较忙嘛,要值班又要去帮忙拍摄年会的视频素材,还得搞个程序员一天的Vlog,还要写BU
加快推动区块链技术和产业创新发展,2019可信区块链峰会在京召开
      11月8日,由中国信息通信研究院、中国通信标准化协会、中国互联网协会、可信区块链推进计划联合主办,科技行者协办的2019可信区块链峰会将在北京悠唐皇冠假日酒店开幕。   区块链技术被认为是继蒸汽机、电力、互联网之后,下一代颠覆性的核心技术。如果说蒸汽机释放了人类的生产力,电力解决了人类基本的生活需求,互联网彻底改变了信息传递的方式,区块链作为构造信任的技术有重要的价值。   1
程序员把地府后台管理系统做出来了,还有3.0版本!12月7号最新消息:已在开发中有github地址
第一幕:缘起 听说阎王爷要做个生死簿后台管理系统,我们派去了一个程序员…… 996程序员做的梦: 第一场:团队招募 为了应对地府管理危机,阎王打算找“人”开发一套地府后台管理系统,于是就在地府总经办群中发了项目需求。 话说还是中国电信的信号好,地府都是满格,哈哈!!! 经常会有外行朋友问:看某网站做的不错,功能也简单,你帮忙做一下? 而这次,面对这样的需求,这个程序员
Android 9.0系统新特性,对刘海屏设备进行适配
其实Android 9.0系统已经是去年推出的“老”系统了,这个系统中新增了一个比较重要的特性,就是对刘海屏设备进行了支持。一直以来我也都有打算针对这个新特性好好地写一篇文章,但是为什么直到拖到了Android 10.0系统都发布了才开始写这篇文章呢?当然,一是因为我这段时间确实比较忙,今年几乎绝大部分的业余时间都放到写新书上了。但是最主要的原因并不是这个,而是因为刘海屏设备的适配存在一定的特殊性
网易云6亿用户音乐推荐算法
网易云音乐是音乐爱好者的集聚地,云音乐推荐系统致力于通过 AI 算法的落地,实现用户千人千面的个性化推荐,为用户带来不一样的听歌体验。 本次分享重点介绍 AI 算法在音乐推荐中的应用实践,以及在算法落地过程中遇到的挑战和解决方案。 将从如下两个部分展开: AI 算法在音乐推荐中的应用 音乐场景下的 AI 思考 从 2013 年 4 月正式上线至今,网易云音乐平台持续提供着:乐屏社区、UGC
【技巧总结】位运算装逼指南
位算法的效率有多快我就不说,不信你可以去用 10 亿个数据模拟一下,今天给大家讲一讲位运算的一些经典例子。不过,最重要的不是看懂了这些例子就好,而是要在以后多去运用位运算这些技巧,当然,采用位运算,也是可以装逼的,不信,你往下看。我会从最简单的讲起,一道比一道难度递增,不过居然是讲技巧,那么也不会太难,相信你分分钟看懂。 判断奇偶数 判断一个数是基于还是偶数,相信很多人都做过,一般的做法的代码如下
日均350000亿接入量,腾讯TubeMQ性能超过Kafka
整理 | 夕颜出品 | AI科技大本营(ID:rgznai100) 【导读】近日,腾讯开源动作不断,相继开源了分布式消息中间件TubeMQ,基于最主流的 OpenJDK8开发的
8年经验面试官详解 Java 面试秘诀
    作者 | 胡书敏 责编 | 刘静 出品 | CSDN(ID:CSDNnews) 本人目前在一家知名外企担任架构师,而且最近八年来,在多家外企和互联网公司担任Java技术面试官,前后累计面试了有两三百位候选人。在本文里,就将结合本人的面试经验,针对Java初学者、Java初级开发和Java开发,给出若干准备简历和准备面试的建议。   Java程序员准备和投递简历的实
面试官如何考察你的思维方式?
1.两种思维方式在求职面试中,经常会考察这种问题:北京有多少量特斯拉汽车? 某胡同口的煎饼摊一年能卖出多少个煎饼? 深圳有多少个产品经理? 一辆公交车里能装下多少个乒乓球? 一
碎片化的时代,如何学习
今天周末,和大家聊聊学习这件事情。 在如今这个社会,我们的时间被各类 APP 撕的粉碎。 刷知乎、刷微博、刷朋友圈; 看论坛、看博客、看公号; 等等形形色色的信息和知识获取方式一个都不错过。 貌似学了很多,但是却感觉没什么用。 要解决上面这些问题,首先要分清楚一点,什么是信息,什么是知识。 那什么是信息呢? 你一切听到的、看到的,都是信息,比如微博上的明星出轨、微信中的表情大战、抖音上的...
so easy! 10行代码写个"狗屁不通"文章生成器
前几天,GitHub 有个开源项目特别火,只要输入标题就可以生成一篇长长的文章。背后实现代码一定很复杂吧,里面一定有很多高深莫测的机器学习等复杂算法不过,当我看了源代码之后这程序不到50
知乎高赞:中国有什么拿得出手的开源软件产品?(整理自本人原创回答)
知乎高赞:中国有什么拿得出手的开源软件产品? 在知乎上,有个问题问“中国有什么拿得出手的开源软件产品(在 GitHub 等社区受欢迎度较好的)?” 事实上,还不少呢~ 本人于2019.7.6进行了较为全面的 回答 - Bravo Yeung,获得该问题下回答中得最高赞(236赞和1枚专业勋章),对这些受欢迎的 Github 开源项目分类整理如下: 分布式计算、云平台相关工具类 1.SkyWalk
MySQL数据库总结
文章目录一、数据库简介二、MySQL数据类型(5.5版本)三、Sql语句(1)Sql语句简介(2)数据定义语言DDLcreate,alter,drop(3)数据操纵语言DMLupdate,insert,delete(4)数据控制语言DCLgrant,revoke(5)数据查询语言DQLselect(6)分组查询与分页查询group by,limit四、完整性约束(单表)五、多表查询六、MySQL数
记一次腾讯面试:进程之间究竟有哪些通信方式?如何通信? ---- 告别死记硬背
有一次面试的时候,被问到进程之间有哪些通信方式,不过由于之前没深入思考且整理过,说的并不好。想必大家也都知道进程有哪些通信方式,可是我猜很多人都是靠着”背“来记忆的,所以今天的这篇文章,讲给大家详细着讲解他们是如何通信的,让大家尽量能够理解他们之间的区别、优缺点等,这样的话,以后面试官让你举例子,你也能够顺手拈来。 1、管道 我们来看一条 Linux 的语句 netstat -tulnp | gr...
20行Python代码爬取王者荣耀全英雄皮肤
引言 王者荣耀大家都玩过吧,没玩过的也应该听说过,作为时下最火的手机MOBA游戏,咳咳,好像跑题了。我们今天的重点是爬取王者荣耀所有英雄的所有皮肤,而且仅仅使用20行Python代码即可完成。 准备工作 爬取皮肤本身并不难,难点在于分析,我们首先得得到皮肤图片的url地址,话不多说,我们马上来到王者荣耀的官网: 我们点击英雄资料,然后随意地选择一位英雄,接着F12打开调试台,找到英雄原皮肤的图片
程序设计的5个底层逻辑,决定你能走多快
阿里妹导读:肉眼看计算机是由CPU、内存、显示器这些硬件设备组成,但大部分人从事的是软件开发工作。计算机底层原理就是连通硬件和软件的桥梁,理解计算机底层原理才能在程序设计这条路上越走越快,越走越轻松。从操作系统层面去理解高级编程语言的执行过程,会发现好多软件设计都是同一种套路,很多语言特性都依赖于底层机制,今天董鹏为你一一揭秘。 结合 CPU 理解一行 Java 代码是怎么执行的 根据冯·诺...
相关热词 c# 二进制截断字符串 c#实现窗体设计器 c#检测是否为微信 c# plc s1200 c#里氏转换原则 c# 主界面 c# do loop c#存为组套 模板 c# 停掉协程 c# rgb 读取图片
立即提问