Python,类中字典变量的疑问

各位python大神,我刚开始看python,遇到一个问题,有点疑惑,请高手给解下惑
我在一个类中定义了一个字符串变量和一个字典变量

我实例化几个实例,分别对字符串变量和字典变量进行赋值,然后存储到一个链表中,最后打印发现字典变量总是会被最后一次赋值覆盖掉。

字典变量在类中就相当于static了吗?

SysConfig 数据库中的信息读取后,保存到这个SysConfig中

 # coding=utf-8

import json

class SysConfig(object):
    syscfg = {
        "id":0,
        "name":"",
        "switcherid":0,
        "centreconsoleid":0,
        "port":"",
        "teacher":"",
        "teacher_panorama":"",
        "student":"",
        "student_panorama":"",
        "blackboard":"",
        "course":""
    }

    test="test";

    def __init__(self):
        pass;

    def getSysCfg(self):
        return self.syscfg;

    def setSyscfg(self, cfg):
        for key in cfg:
            self.syscfg[key] = cfg[key];

        self.test = str(cfg);

数据库操作类:

 import pymysql
import sys
import sysCfg

class GSDBHelper(object):

    dbConfig = {
          'host':'127.0.0.1',
          'port':3306,
          'user':'root',
          'password':'root',
          'db':'db_test',
          'charset':'utf8mb4',
          'cursorclass':pymysql.cursors.DictCursor
          }

    dbsysList = [];

    def __init__(self):
        pass;

    def setConfig(self, host, port, user, password, db):
        self.dbConfig["host"] = host;
        self.dbConfig["port"] = port;
        self.dbConfig["user"] = user;
        self.dbConfig["password"] = password;
        self.dbConfig["db"] = db;

        # print("db config...", self.dbConfig);

    def getSysConfig(self):
        try:
            self.dbsysList.clear();

            conn = pymysql.connect(**self.dbConfig);
            cur = conn.cursor();
            cur.execute("select * from dbt_sys");
            rows = cur.fetchall();
            for row in rows:
                dbsys = sysCfg.SysConfig();
                dbsys.setSyscfg(row);
                self.dbsysList.append(dbsys);

            for sysItem in self.dbsysList:
                print(sysItem, " syscfg... ", sysItem.getSysCfg());
                print(sysItem, " test 字符串  ", sysItem.test);

        except:
            info = sys.exc_info();
            print(info[0], ":", info[1]);
        finally:
            if "conn" in locals().keys():
                conn.close();

执行:

 import dbhelper

if __name__ == '__main__':
    db = dbhelper.GSDBHelper();
    db.setConfig("127.0.0.1", 3306, "root", "ab123@", "db_gsaid");
    db.getSysConfig();

运行结果:

 <sysCfg.SysConfig object at 0x0293A2D0>  syscfg...  {'id': 2, 'name': None, 'switcherid': 1, 'port': '{"sprite":99991, "cameractrl1":2001, "cameractrl2":2002, "cameractrl3":2003, "cameractrl4":2004}', 'centreconsoleid': None, 'teacher': None, 'teacher_panorama': None, 'student': None, 'student_panorama': None, 'blackboard': None, 'course': None, 'info': None}
<sysCfg.SysConfig object at 0x0293A2D0>  test 字符串   {'id': 1, 'name': '', 'switcherid': 1, 'centreconsoleid': 1, 'port': '{"sprite":9999, "cameractrl1":2001, "cameractrl2":20021, "cameractrl3":2003, "cameractrl4":2004}', 'teacher': '{"time":{"min":2000, "max":600000}, "lostormax_cutto":"student_panorama", "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0}}', 'teacher_panorama': '{"time":{"min":2000, "max":100000}, "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0} }', 'student': '{"time":{"min":2000, "max":100000} , "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0}}', 'student_panorama': '{"time":{"min":2000, "max":100000} , "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0}}', 'blackboard': '{"time":{"min":2000, "max":100000}, "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0} }', 'course': '{"time":{"min":2000, "max":100000} , "effect1":{"id":16, "pect":100}, "effect2":{"id":16, "pect":0}}', 'info': None}
<sysCfg.SysConfig object at 0x0293A2F0>  syscfg...  {'id': 2, 'name': None, 'switcherid': 1, 'port': '{"sprite":99991, "cameractrl1":2001, "cameractrl2":2002, "cameractrl3":2003, "cameractrl4":2004}', 'centreconsoleid': None, 'teacher': None, 'teacher_panorama': None, 'student': None, 'student_panorama': None, 'blackboard': None, 'course': None, 'info': None}
<sysCfg.SysConfig object at 0x0293A2F0>  test 字符串   {'id': 2, 'name': None, 'switcherid': 1, 'centreconsoleid': None, 'port': '{"sprite":99991, "cameractrl1":2001, "cameractrl2":2002, "cameractrl3":2003, "cameractrl4":2004}', 'teacher': None, 'teacher_panorama': None, 'student': None, 'student_panorama': None, 'blackboard': None, 'course': None, 'info': None}

1个回答

原来这样。。。

要理解一下python跟别的面向对象语言的不同,应该这样:

 class people:

    def __init__(self, s, n):
       self. m = {}
       self.m[s] = n
    def show(self):
       self. m = {}
       print(self.m)

p1 = people('年龄', 1)
p2 = people('年龄', 2)

p1.show()
p2.show()
Csdn user default icon
上传中...
上传图片
插入图片
抄袭、复制答案,以达到刷声望分或其他目的的行为,在CSDN问答是严格禁止的,一经发现立刻封号。是时候展现真正的技术了!
其他相关推荐
python字典中键值使用变量?
``` dict={'1':["TEST1"], '2':["TEST2"], '3':["TEST3"], } i=1 x=dict["%d"] %d i print(x) ``` 我想使用一个变量去引用键值,结果发现这样操作不可行,请问下有大神知道该怎么解决么
python构建get请求如何把字典参数作为一个变量传进去
![图片说明](https://img-ask.csdn.net/upload/201912/05/1575555981_756412.png)
python字典的值为字母,怎样变成整数使其一一对应关系
多张excel表格,要提取excel表格的某一列做数据分析,表格不是固定的格式,想通过变量来提取,excel的列是以字母排序的 比如: import pandas as pd dic={"长":D,"宽":G} #长的数据在D列,宽的数据在G列 Length=dic.get("长") Width=dic.get("宽") df=pd.read_excel("d:/ex.xlsx") df1=df.iloc[:,[Length,Width]] ...... _问题是怎样把 Length=dic.get("长") Width=dic.get("宽") 变成 Length=3 Width=6 即把字典的字母值变成数字值,A-Z和0-25一一对应关系, 值为A时变成值为0,值为B时变成值为1,依此类推!谢谢大家!_
python+mysql:在插入语句中使用变量,变量中含有多个引号,和双引号,该如何使用sql语句
数据是一个字典类型: my_dict={"2017-07" : [ {"origin" : "LJ","price" : 44267,"crawl_date" : "2017-09-01"}]} item={'name':'万科','city_name':'深圳','location':'龙岗','price':str(my_dict)} 把字典转化为str后,变成了item。 插入语句: sql2='''insert into house ( name,city_name,location,price) values ('%s','%s','%s','%s')''' %(item['name'], item['city_name'], item['location'], item['price'].encode('utf-8')) 因为price中含有很多引号,{"2017-07" : [ {"origin" : "LJ","price" : 44267,"crawl_date" : "2017-09-01"}]} 所以运行的时候回提示说sql错误,要怎样才能把这些字符显示为字符而非转义字符 ?
为什么全局变量在进程中赋值后,线程中接收不到。
我定义了一个global全局变量,在一个进程中对这个变量赋值,然后想在一个线程中使用这个变量,但是发现线程中并没有收到应该被赋值的全局变量,不知道是什么原因,以下是我的部分代码 ``` global a a = {} class Control_system(QMainWindow, Ui_Control_system): socketQueue = multiprocessing.Queue() def __init__(self, parent=None): super().__init__(parent) self.setupUi(self) self.p1 = multiprocessing.Process(target=Control_system.connect, args=(self.socketQueue,)) self.p1.start() self.timer = QTimer(self) self.timer.timeout.connect(self.client) self.timer.start(1000) @staticmethod def connect(queue): ip_port = ("192.168.1.251", 8880) s = socketserver.ThreadingTCPServer(ip_port, MyServer) s.serve_forever() def client(self): print(a) # 这里的a还是空字典 class MyServer(socketserver.BaseRequestHandler): def handle(self): print("conn is :", self.request) # conn print("addr is :", self.client_address) # addr a[self.client_address] = self.request print("a:",a) # 这里的a是有值的 if __name__ == '__main__': if not QApplication.instance(): app = QApplication(sys.argv) else: app = QApplication.instance() w = Control_system() w.show() sys.exit(app.exec()) ```
C++调用python脚本(test.py这个脚本中import numpy)程序崩溃
我想在c++中调用python的一个脚本,这个脚本中我只是import了一个numpy就报错了,而如果是简单的脚本(没有import第三方库)就不会出错,我已经把: INCLUDEPATH += C:/Python27/include/ LIBS += C:/Python27/libs/python27.lib 添加进去了, ``` pyrun_simplestring("import sys"); pyrun_simplestring("import numpy"); pyrun_simplestring("sys.path.append('c:\python27\lib\site-packages\')"); pyerr_print(); pyobject * pmodule = null; pyobject * pfunc = null; pmodule =pyimport_importmodule("test_my"); //test001:python文件名 pfunc= pyobject_getattrstring(pmodule, "testdict"); //add:python文件中的函数名 pyobject *pargs = pytuple_new(1); pyobject *pdict = pydict_new(); //创建字典类型变量 pydict_setitemstring(pdict, "name", py_buildvalue("s", "wangyao")); //往字典类型变量中填充数据 pydict_setitemstring(pdict, "age", py_buildvalue("i", 25)); //往字典类型变量中填充数据 pytuple_setitem(pargs, 0, pdict); //0---序号 将字典类型变量添加到参数元组中 pyobject *preturn = null; preturn = pyeval_callobject(pfunc, pargs); //调用函数 int size = pydict_size(preturn); cout << "返回字典的大小为: " << size << endl; pyobject *pnewage = pydict_getitemstring(preturn, "age"); int newage; pyarg_parse(pnewage, "i", &newage); cout << "true age: " << newage << endl; py_finalize(); ``` 这是python的脚本: ``` #import numpy as np def HelloWorld(): print "Hello World" def add(a, b): #tmp=np.random.randint(10,88) return a+b def TestDict(dict): print dict dict["Age"] = 17 return dict class Person: def greet(self, greetStr): print greetStr #print add(5,7) #a = raw_input("Enter To Continue...") ``` 老是报错,但如果我把import numpy去掉就没问题,求大神解答,困扰好久了~~~
python子类继承父类时找不到__init__
![UML](https://img-ask.csdn.net/upload/201901/06/1546726233_455371.png) 类PatientManagement有3个方法 1. add_patient 添加一个新病人(属于类 Patient)。病人对应一个ID,第一个ID为0,第二个为1,以此类推 2. get_patient 返回病人的ID 3.get_statistics 返回女病人和男病人的数量,以及平均BMI 病人和ID以字典的方式储存在变量 patients 中,以{id: Patient}形式,最后一个使用的ID储存在变量 last_used 中。 类 Patient 有两个方法 1. to_string 打印 名 姓 出生日 月 年 性别 身高(cm)体重(kg) 2. get_bmi 计算这个病人的BMI,BMI = weight in KG / (height in M)² get_statistics调用get_bmi 且不再进行一变BMI计算 运行结果是这样 ``` class Patient(object): def __init__(self, first_name, last_name, birth_year, birth_month, birth_day, sex, body_height, body_weight): self.first_name = first_name self.last_name = last_name self.year = birth_year self.month = birth_month self.day = birth_day self.sex = sex self.height = body_height self.weight = body_weight def to_string(self): print("Name: {} {}, geboren am {}.{}.{}, geschlecht:{}, {}cm, {}kg".format(self.first_name, self.last_name, self.day, self.month, self.year, self.sex, self.height, self.weight)) def get_bmi(self): return self.weight / ((self.height/100) ** 2) a =[] b = [] patients = {} i = 0 class PatientManagement(Patient): fpCount = 0 mpCount = 0 def add_patient(self): patients[i] = Patient() i += 1 if self.sex == f: PatientManagement.fpCount += 1 if self.sex == m: PatientManagement.mpCount += 1 for k in patients.keys(): a.append(k) b.append(patients[k]) last_used = a[-1] def get_patient(self): new_dict = {v:k for k,v in patients.items()} patient_id = new_dict[Patient()] return patients_id def get_statistics(self): bmi = [] for patient in b: bmi.append(patient.get_bmi()) avg_bmi = np.mean(bmi) return "female Patients:{}, male Patients:{}, average BMI: {}".format(PatientManagement.fpCount, PatientManagement.mpCount, avg_bmi) ``` ``` p1 = Patient("Meier", "Lena", 1988, 12, 12, 'f', 164, 50) pm1 = PatientManagement(p1) pm1.add_patient() pm1.get_patient() ``` 这里报错是问题在哪里? ``` --------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-141-ebb011904724> in <module>() 1 p1 = Patient("Meier", "Lena", 1988, 12, 12, 'f', 164, 50) ----> 2 pm1 = PatientManagement(p1) 3 pm1.add_patient() 4 pm1.get_patient() TypeError: __init__() missing 7 required positional arguments: 'last_name', 'birth_year', 'birth_month', 'birth_day', 'sex', 'body_height', and 'body_weight' ```
python爬新浪新闻内容,为什么运行完stock里面为空……
#! /usr/bin/env python #coding=utf-8 from scrapy.selector import Selector from scrapy.http import Request import re,os from bs4 import BeautifulSoup from scrapy.spider import Spider import urllib2,thread #处理编码问题 import sys reload(sys) sys.setdefaultencoding('gb18030') #flag的作用是保证第一次爬取的时候不进行单个新闻页面内容的爬取 flag=1 projectpath='C:\\Users\DELL\\Desktop\\pythonproject\\mypro\\' def loop(*response): sel = Selector(response[0]) #get title title = sel.xpath('//h1/text()').extract() #get pages pages=sel.xpath('//div[@id="artibody"]//p/text()').extract() #get chanel_id & comment_id s=sel.xpath('//meta[@name="comment"]').extract() #comment_id = channel[index+3:index+15] index2=len(response[0].url) news_id=response[0].url[index2-14:index2-6] comment_id='31-1-'+news_id #评论内容都在这个list中 cmntlist=[] page=1 #含有新闻url,标题,内容,评论的文件 file2=None #该变量的作用是当某新闻下存在非手机用户评论时置为False is_all_tel=True while((page==1) or (cmntlist != [])): tel_count=0 #each page tel_user_count #提取到的评论url url="http://comment5.news.sina.com.cn/page/info?version=1&format=js&channel=cj&newsid="+str(comment_id)+"&group=0&compress=1&ie=gbk&oe=gbk&page="+str(page)+"&page_size=100" url_contain=urllib2.urlopen(url).read() b='={' after = url_contain[url_contain.index(b)+len(b)-1:] #字符串中的None对应python中的null,不然执行eval时会出错 after=after.replace('null','None') #转换为字典变量text text=eval(after) if 'cmntlist' in text['result']: cmntlist=text['result']['cmntlist'] else: cmntlist=[] if cmntlist != [] and (page==1): filename=str(comment_id)+'.txt' path=projectpath+'stock\\' +filename file2=open(path,'a+') news_content=str('') for p in pages: news_content=news_content+p+'\n' item="<url>"+response[0].url+"</url>"+'\n\n'+"<title>"+str(title[0])+"</title>\n\n"+"<content>\n"+str(news_content)+"</content>\n\n<comment>\n" file2.write(item) if cmntlist != []: content='' for status_dic in cmntlist: if status_dic['uid']!='0': is_all_tel=False #这一句视编码情况而定,在这里去掉decode和encode也行 s=status_dic['content'].decode('UTF-8').encode('GBK') #见另一篇博客“三张图” s=s.replace("'",'"') s=s.replace("\n",'') s1="u'"+s+"'" try: ss=eval(s1) except: try: s1='u"'+s+'"' ss=eval(s1) except: return content=content+status_dic['time']+'\t'+status_dic['uid']+'\t'+ss+'\n' #当属于手机用户时 else: tel_count=tel_count+1 #当一个page下不都是手机用户时,这里也可以用is_all_tel进行判断,一种是用开关的方式,一种是统计的方式 #算了不改了 if tel_count!=len(cmntlist): file2.write(content) page=page+1 #while loop end here if file2!=None: #当都是手机用户时,移除文件,否则写入"</comment>"到文件尾 if is_all_tel: file2.close() try: os.remove(file2.name) except WindowsError: pass else: file2.write("</comment>") file2.close() class DmozSpider(Spider): name = "stock" allowed_domains = ["sina.com.cn"] #在本程序中,start_urls并不重要,因为并没有解析 start_urls = [ "http://news.sina.com.cn/" ] global projectpath if os.path.exists(projectpath+'stock'): pass else: os.mkdir(projectpath+'stock') def parse(self, response): #这个scrapy.selector.Selector是个不错的处理字符串的类,python对编码很严格,它却处理得很好 #在做这个爬虫的时候,碰到很多奇奇怪怪的编码问题,主要是中文,试过很多既有的类,BeautifulSoup处理得也不是很好 sel = Selector(response) global flag if(flag==1): flag=2 page=1 while page<260: url="http://roll.finance.sina.com.cn/finance/zq1/index_" url=url+str(page)+".shtml" #伪装为浏览器 user_agent = 'Mozilla/4.0 (compatible; MSIE 5.5; Windows NT)' headers = { 'User-Agent' : user_agent } req = urllib2.Request(url, headers=headers) response = urllib2.urlopen(req) url_contain = response.read() #利用BeautifulSoup进行文档解析 soup = BeautifulSoup(url_contain) params = soup.findAll('div',{'class':'listBlk'}) if os.path.exists(projectpath+'stock\\'+'link'): pass else: os.mkdir(projectpath+'stock\\'+'link') filename='link.txt' path=projectpath+'stock\\link\\' + filename filelink=open(path,'a+') for params_item in params: persons = params_item.findAll('li') for item in persons: href=item.find('a') mil_link= href.get('href') filelink.write(str(mil_link)+'\n') #递归调用parse,传入新的爬取url yield Request(mil_link, callback=self.parse) page=page+1 #对单个新闻页面新建线程进行爬取 if flag!=1: if (response.status != 404) and (response.status != 502): thread.start_new_thread(loop,(response,))
Python多线程问题,target以及kwargs传参出错,请问应该怎么写
``` def A(a,b,c): 代码块省略 def B(a,b,c,d): 代码块省略 def thread(self,arg*): t1 = threading.Thread(target=A,args=(a,b,c)) ``` 问题一:这里我想参数target=需要多开线程的方法名,然后我随便定义一个变量作为方法名参数传到target里面,,不行,,程序报错。求正确的传参方法,难不成我要为每一个要多开的方法都要多写一个多线程方法,,仅仅改个方法名参数?这太麻烦了 问题二:args这个参数我想改成kwargs字典形式的参数应该怎么改。t1 = threading.Thread(target=A,kwargs={a=1,b=2,c=3}) 这样报错,,然后改成t1 = threading.Thread(target=A,kwargs={‘a=1,b=2,c=3’})再改成一对参数和值加一组单引号还是报错,,求正确格式。。
基于tensorflow的pix2pix代码中如何做到输入图像和输出图像分辨率不一致
问题:例如在自己制作了成对的输入(input256×256 target 200×256)后,如何让输入图像和输出图像分辨率不一致,例如成对图像中:input的分辨率是256×256, output 和target都是200×256,需要修改哪里的参数。 论文参考:《Image-to-Image Translation with Conditional Adversarial Networks》 代码参考:https://blog.csdn.net/MOU_IT/article/details/80802407?utm_source=blogxgwz0 # coding=utf-8 from __future__ import absolute_import from __future__ import division from __future__ import print_function import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf import numpy as np import os import glob import random import collections import math import time # https://github.com/affinelayer/pix2pix-tensorflow train_input_dir = "D:/Project/pix2pix-tensorflow-master/facades/train/" # 训练集输入 train_output_dir = "D:/Project/pix2pix-tensorflow-master/facades/train_out/" # 训练集输出 test_input_dir = "D:/Project/pix2pix-tensorflow-master/facades/val/" # 测试集输入 test_output_dir = "D:/Project/pix2pix-tensorflow-master/facades/test_out/" # 测试集的输出 checkpoint = "D:/Project/pix2pix-tensorflow-master/facades/train_out/" # 保存结果的目录 seed = None max_steps = None # number of training steps (0 to disable) max_epochs = 200 # number of training epochs progress_freq = 50 # display progress every progress_freq steps trace_freq = 0 # trace execution every trace_freq steps display_freq = 50 # write current training images every display_freq steps save_freq = 500 # save model every save_freq steps, 0 to disable separable_conv = False # use separable convolutions in the generator aspect_ratio = 1 #aspect ratio of output images (width/height) batch_size = 1 # help="number of images in batch") which_direction = "BtoA" # choices=["AtoB", "BtoA"]) ngf = 64 # help="number of generator filters in first conv layer") ndf = 64 # help="number of discriminator filters in first conv layer") scale_size = 286 # help="scale images to this size before cropping to 256x256") flip = True # flip images horizontally no_flip = True # don't flip images horizontally lr = 0.0002 # initial learning rate for adam beta1 = 0.5 # momentum term of adam l1_weight = 100.0 # weight on L1 term for generator gradient gan_weight = 1.0 # weight on GAN term for generator gradient output_filetype = "png" # 输出图像的格式 EPS = 1e-12 # 极小数,防止梯度为损失为0 CROP_SIZE = 256 # 图片的裁剪大小 # 命名元组,用于存放加载的数据集合创建好的模型 Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch") Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train") # 图像预处理 [0, 1] => [-1, 1] def preprocess(image): with tf.name_scope("preprocess"): return image * 2 - 1 # 图像后处理[-1, 1] => [0, 1] def deprocess(image): with tf.name_scope("deprocess"): return (image + 1) / 2 # 判别器的卷积定义,batch_input为 [ batch , 256 , 256 , 6 ] def discrim_conv(batch_input, out_channels, stride): # [ batch , 256 , 256 , 6 ] ===>[ batch , 258 , 258 , 6 ] padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT") ''' [0,0]: 第一维batch大小不扩充 [1,1]:第二维图像宽度左右各扩充一列,用0填充 [1,1]:第三维图像高度上下各扩充一列,用0填充 [0,0]:第四维图像通道不做扩充 ''' return tf.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid", kernel_initializer=tf.random_normal_initializer(0, 0.02)) # 生成器的卷积定义,卷积核为4*4,步长为2,输出图像为输入的一半 def gen_conv(batch_input, out_channels): # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels] initializer = tf.random_normal_initializer(0, 0.02) if separable_conv: return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer) else: return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer) # 生成器的反卷积定义 def gen_deconv(batch_input, out_channels): # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels] initializer = tf.random_normal_initializer(0, 0.02) if separable_conv: _b, h, w, _c = batch_input.shape resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer) else: return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer) # 定义LReLu激活函数 def lrelu(x, a): with tf.name_scope("lrelu"): # adding these together creates the leak part and linear part # then cancels them out by subtracting/adding an absolute value term # leak: a*x/2 - a*abs(x)/2 # linear: x/2 + abs(x)/2 # this block looks like it has 2 inputs on the graph unless we do this x = tf.identity(x) return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x) # 批量归一化图像 def batchnorm(inputs): return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02)) # 检查图像的维度 def check_image(image): assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels") with tf.control_dependencies([assertion]): image = tf.identity(image) if image.get_shape().ndims not in (3, 4): raise ValueError("image must be either 3 or 4 dimensions") # make the last dimension 3 so that you can unstack the colors shape = list(image.get_shape()) shape[-1] = 3 image.set_shape(shape) return image # 去除文件的后缀,获取文件名 def get_name(path): # os.path.basename(),返回path最后的文件名。若path以/或\结尾,那么就会返回空值。 # os.path.splitext(),分离文件名与扩展名;默认返回(fname,fextension)元组 name, _ = os.path.splitext(os.path.basename(path)) return name # 加载数据集,从文件读取-->解码-->归一化--->拆分为输入和目标-->像素转为[-1,1]-->转变形状 def load_examples(input_dir): if input_dir is None or not os.path.exists(input_dir): raise Exception("input_dir does not exist") # 匹配第一个参数的路径中所有的符合条件的文件,并将其以list的形式返回。 input_paths = glob.glob(os.path.join(input_dir, "*.jpg")) # 图像解码器 decode = tf.image.decode_jpeg if len(input_paths) == 0: input_paths = glob.glob(os.path.join(input_dir, "*.png")) decode = tf.image.decode_png if len(input_paths) == 0: raise Exception("input_dir contains no image files") # 如果文件名是数字,则用数字进行排序,否则用字母排序 if all(get_name(path).isdigit() for path in input_paths): input_paths = sorted(input_paths, key=lambda path: int(get_name(path))) else: input_paths = sorted(input_paths) sess = tf.Session() with tf.name_scope("load_images"): # 把我们需要的全部文件打包为一个tf内部的queue类型,之后tf开文件就从这个queue中取目录了, # 如果是训练模式时,shuffle为True path_queue = tf.train.string_input_producer(input_paths, shuffle=True) # Read的输出将是一个文件名(key)和该文件的内容(value,每次读取一个文件,分多次读取)。 reader = tf.WholeFileReader() paths, contents = reader.read(path_queue) # 对文件进行解码并且对图片作归一化处理 raw_input = decode(contents) raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32) # 归一化处理 # 判断两个值知否相等,如果不等抛出异常 assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels") ''' 对于control_dependencies这个管理器,只有当里面的操作是一个op时,才会生效,也就是先执行传入的 参数op,再执行里面的op。如果里面的操作不是定义的op,图中就不会形成一个节点,这样该管理器就失效了。 tf.identity是返回一个一模一样新的tensor的op,这会增加一个新节点到gragh中,这时control_dependencies就会生效. ''' with tf.control_dependencies([assertion]): raw_input = tf.identity(raw_input) raw_input.set_shape([None, None, 3]) # 图像值由[0,1]--->[-1, 1] width = tf.shape(raw_input)[1] # [height, width, channels] a_images = preprocess(raw_input[:, :width // 2, :]) # 256*256*3 b_images = preprocess(raw_input[:, width // 2:, :]) # 256*256*3 # 这里的which_direction为:BtoA if which_direction == "AtoB": inputs, targets = [a_images, b_images] elif which_direction == "BtoA": inputs, targets = [b_images, a_images] else: raise Exception("invalid direction") # synchronize seed for image operations so that we do the same operations to both # input and output images seed = random.randint(0, 2 ** 31 - 1) # 图像预处理,翻转、改变形状 with tf.name_scope("input_images"): input_images = transform(inputs) with tf.name_scope("target_images"): target_images = transform(targets) # 获得输入图像、目标图像的batch块 paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=batch_size) steps_per_epoch = int(math.ceil(len(input_paths) / batch_size)) return Examples( paths=paths_batch, # 输入的文件名块 inputs=inputs_batch, # 输入的图像块 targets=targets_batch, # 目标图像块 count=len(input_paths), # 数据集的大小 steps_per_epoch=steps_per_epoch, # batch的个数 ) # 图像预处理,翻转、改变形状 def transform(image): r = image if flip: r = tf.image.random_flip_left_right(r, seed=seed) # area produces a nice downscaling, but does nearest neighbor for upscaling # assume we're going to be doing downscaling here r = tf.image.resize_images(r, [scale_size, scale_size], method=tf.image.ResizeMethod.AREA) offset = tf.cast(tf.floor(tf.random_uniform([2], 0, scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32) if scale_size > CROP_SIZE: r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE) elif scale_size < CROP_SIZE: raise Exception("scale size cannot be less than crop size") return r # 创建生成器,这是一个编码解码器的变种,输入输出均为:256*256*3, 像素值为[-1,1] def create_generator(generator_inputs, generator_outputs_channels): layers = [] # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf] with tf.variable_scope("encoder_1"): output = gen_conv(generator_inputs, ngf) # ngf为第一个卷积层的卷积核核数量,默认为 64 layers.append(output) layer_specs = [ ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2] ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4] ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8] ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8] ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8] ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8] ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8] ] # 卷积的编码器 for out_channels in layer_specs: with tf.variable_scope("encoder_%d" % (len(layers) + 1)): # 对最后一层使用激活函数 rectified = lrelu(layers[-1], 0.2) # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels] convolved = gen_conv(rectified, out_channels) output = batchnorm(convolved) layers.append(output) layer_specs = [ (ngf * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2] (ngf * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2] (ngf * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2] (ngf * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2] (ngf * 4, 0.0), # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2] (ngf * 2, 0.0), # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2] (ngf, 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2] ] # 卷积的解码器 num_encoder_layers = len(layers) # 8 for decoder_layer, (out_channels, dropout) in enumerate(layer_specs): skip_layer = num_encoder_layers - decoder_layer - 1 with tf.variable_scope("decoder_%d" % (skip_layer + 1)): if decoder_layer == 0: # first decoder layer doesn't have skip connections # since it is directly connected to the skip_layer input = layers[-1] else: input = tf.concat([layers[-1], layers[skip_layer]], axis=3) rectified = tf.nn.relu(input) # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels] output = gen_deconv(rectified, out_channels) output = batchnorm(output) if dropout > 0.0: output = tf.nn.dropout(output, keep_prob=1 - dropout) layers.append(output) # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels] with tf.variable_scope("decoder_1"): input = tf.concat([layers[-1], layers[0]], axis=3) rectified = tf.nn.relu(input) output = gen_deconv(rectified, generator_outputs_channels) output = tf.tanh(output) layers.append(output) return layers[-1] # 创建判别器,输入生成的图像和真实的图像:两个[batch,256,256,3],元素值值[-1,1],输出:[batch,30,30,1],元素值为概率 def create_discriminator(discrim_inputs, discrim_targets): n_layers = 3 layers = [] # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2] input = tf.concat([discrim_inputs, discrim_targets], axis=3) # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf] with tf.variable_scope("layer_1"): convolved = discrim_conv(input, ndf, stride=2) rectified = lrelu(convolved, 0.2) layers.append(rectified) # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2] # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4] # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8] for i in range(n_layers): with tf.variable_scope("layer_%d" % (len(layers) + 1)): out_channels = ndf * min(2 ** (i + 1), 8) stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1 convolved = discrim_conv(layers[-1], out_channels, stride=stride) normalized = batchnorm(convolved) rectified = lrelu(normalized, 0.2) layers.append(rectified) # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1] with tf.variable_scope("layer_%d" % (len(layers) + 1)): convolved = discrim_conv(rectified, out_channels=1, stride=1) output = tf.sigmoid(convolved) layers.append(output) return layers[-1] # 创建Pix2Pix模型,inputs和targets形状为:[batch_size, height, width, channels] def create_model(inputs, targets): with tf.variable_scope("generator"): out_channels = int(targets.get_shape()[-1]) outputs = create_generator(inputs, out_channels) # create two copies of discriminator, one for real pairs and one for fake pairs # they share the same underlying variables with tf.name_scope("real_discriminator"): with tf.variable_scope("discriminator"): # 2x [batch, height, width, channels] => [batch, 30, 30, 1] predict_real = create_discriminator(inputs, targets) # 条件变量图像和真实图像 with tf.name_scope("fake_discriminator"): with tf.variable_scope("discriminator", reuse=True): # 2x [batch, height, width, channels] => [batch, 30, 30, 1] predict_fake = create_discriminator(inputs, outputs) # 条件变量图像和生成的图像 # 判别器的损失,判别器希望V(G,D)尽可能大 with tf.name_scope("discriminator_loss"): # minimizing -tf.log will try to get inputs to 1 # predict_real => 1 # predict_fake => 0 discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS))) # 生成器的损失,生成器希望V(G,D)尽可能小 with tf.name_scope("generator_loss"): # predict_fake => 1 # abs(targets - outputs) => 0 gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS)) gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs)) gen_loss = gen_loss_GAN * gan_weight + gen_loss_L1 * l1_weight # 判别器训练 with tf.name_scope("discriminator_train"): # 判别器需要优化的参数 discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] # 优化器定义 discrim_optim = tf.train.AdamOptimizer(lr, beta1) # 计算损失函数对优化参数的梯度 discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars) # 更新该梯度所对应的参数的状态,返回一个op discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars) # 生成器训练 with tf.name_scope("generator_train"): with tf.control_dependencies([discrim_train]): # 生成器需要优化的参数列表 gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] # 定义优化器 gen_optim = tf.train.AdamOptimizer(lr, beta1) # 计算需要优化的参数的梯度 gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars) # 更新该梯度所对应的参数的状态,返回一个op gen_train = gen_optim.apply_gradients(gen_grads_and_vars) ''' 在采用随机梯度下降算法训练神经网络时,使用 tf.train.ExponentialMovingAverage 滑动平均操作的意义在于 提高模型在测试数据上的健壮性(robustness)。tensorflow 下的 tf.train.ExponentialMovingAverage 需要 提供一个衰减率(decay)。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度, ExponentialMovingAverage 对每一个(待更新训练学习的)变量(variable)都会维护一个影子变量 (shadow variable)。影子变量的初始值就是这个变量的初始值, shadow_variable=decay×shadow_variable+(1−decay)×variable ''' ema = tf.train.ExponentialMovingAverage(decay=0.99) update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1]) # global_step = tf.train.get_or_create_global_step() incr_global_step = tf.assign(global_step, global_step + 1) return Model( predict_real=predict_real, # 条件变量(输入图像)和真实图像之间的概率值,形状为;[batch,30,30,1] predict_fake=predict_fake, # 条件变量(输入图像)和生成图像之间的概率值,形状为;[batch,30,30,1] discrim_loss=ema.average(discrim_loss), # 判别器损失 discrim_grads_and_vars=discrim_grads_and_vars, # 判别器需要优化的参数和对应的梯度 gen_loss_GAN=ema.average(gen_loss_GAN), # 生成器的损失 gen_loss_L1=ema.average(gen_loss_L1), # 生成器的 L1损失 gen_grads_and_vars=gen_grads_and_vars, # 生成器需要优化的参数和对应的梯度 outputs=outputs, # 生成器生成的图片 train=tf.group(update_losses, incr_global_step, gen_train), # 打包需要run的操作op ) # 保存图像 def save_images(output_dir, fetches, step=None): image_dir = os.path.join(output_dir, "images") if not os.path.exists(image_dir): os.makedirs(image_dir) filesets = [] for i, in_path in enumerate(fetches["paths"]): name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8"))) fileset = {"name": name, "step": step} for kind in ["inputs", "outputs", "targets"]: filename = name + "-" + kind + ".png" if step is not None: filename = "%08d-%s" % (step, filename) fileset[kind] = filename out_path = os.path.join(image_dir, filename) contents = fetches[kind][i] with open(out_path, "wb") as f: f.write(contents) filesets.append(fileset) return filesets # 将结果写入HTML网页 def append_index(output_dir, filesets, step=False): index_path = os.path.join(output_dir, "index.html") if os.path.exists(index_path): index = open(index_path, "a") else: index = open(index_path, "w") index.write("<html><body><table><tr>") if step: index.write("<th>step</th>") index.write("<th>name</th><th>input</th><th>output</th><th>target</th></tr>") for fileset in filesets: index.write("<tr>") if step: index.write("<td>%d</td>" % fileset["step"]) index.write("<td>%s</td>" % fileset["name"]) for kind in ["inputs", "outputs", "targets"]: index.write("<td><img src='images/%s'></td>" % fileset[kind]) index.write("</tr>") return index_path # 转变图像的尺寸、并且将[0,1]--->[0,255] def convert(image): if aspect_ratio != 1.0: # upscale to correct aspect ratio size = [CROP_SIZE, int(round(CROP_SIZE * aspect_ratio))] image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC) # 将数据的类型转换为8位无符号整型 return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True) # 主函数 def train(): # 设置随机数种子的值 global seed if seed is None: seed = random.randint(0, 2 ** 31 - 1) tf.set_random_seed(seed) np.random.seed(seed) random.seed(seed) # 创建目录 if not os.path.exists(train_output_dir): os.makedirs(train_output_dir) # 加载数据集,得到输入数据和目标数据并把范围变为 :[-1,1] examples = load_examples(train_input_dir) print("load successful ! examples count = %d" % examples.count) # 创建模型,inputs和targets是:[batch_size, height, width, channels] # 返回值: model = create_model(examples.inputs, examples.targets) print("create model successful!") # 图像处理[-1, 1] => [0, 1] inputs = deprocess(examples.inputs) targets = deprocess(examples.targets) outputs = deprocess(model.outputs) # 把[0,1]的像素点转为RGB值:[0,255] with tf.name_scope("convert_inputs"): converted_inputs = convert(inputs) with tf.name_scope("convert_targets"): converted_targets = convert(targets) with tf.name_scope("convert_outputs"): converted_outputs = convert(outputs) # 对图像进行编码以便于保存 with tf.name_scope("encode_images"): display_fetches = { "paths": examples.paths, # tf.map_fn接受一个函数对象和集合,用函数对集合中每个元素分别处理 "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"), "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"), "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"), } with tf.name_scope("parameter_count"): parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()]) # 只保存最新一个checkpoint saver = tf.train.Saver(max_to_keep=20) init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) print("parameter_count =", sess.run(parameter_count)) if max_epochs is not None: max_steps = examples.steps_per_epoch * max_epochs # 400X200=80000 # 因为是从文件中读取数据,所以需要启动start_queue_runners() # 这个函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) # 运行训练集 print("begin trainning......") print("max_steps:", max_steps) start = time.time() for step in range(max_steps): def should(freq): return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1) print("step:", step) # 定义一个需要run的所有操作的字典 fetches = { "train": model.train } # progress_freq为 50,每50次计算一次三个损失,显示进度 if should(progress_freq): fetches["discrim_loss"] = model.discrim_loss fetches["gen_loss_GAN"] = model.gen_loss_GAN fetches["gen_loss_L1"] = model.gen_loss_L1 # display_freq为 50,每50次保存一次输入、目标、输出的图像 if should(display_freq): fetches["display"] = display_fetches # 运行各种操作, results = sess.run(fetches) # display_freq为 50,每50次保存输入、目标、输出的图像 if should(display_freq): print("saving display images") filesets = save_images(train_output_dir, results["display"], step=step) append_index(train_output_dir, filesets, step=True) # progress_freq为 50,每50次打印一次三种损失的大小,显示进度 if should(progress_freq): # global_step will have the correct step count if we resume from a checkpoint train_epoch = math.ceil(step / examples.steps_per_epoch) train_step = (step - 1) % examples.steps_per_epoch + 1 rate = (step + 1) * batch_size / (time.time() - start) remaining = (max_steps - step) * batch_size / rate print("progress epoch %d step %d image/sec %0.1f remaining %dm" % ( train_epoch, train_step, rate, remaining / 60)) print("discrim_loss", results["discrim_loss"]) print("gen_loss_GAN", results["gen_loss_GAN"]) print("gen_loss_L1", results["gen_loss_L1"]) # save_freq为500,每500次保存一次模型 if should(save_freq): print("saving model") saver.save(sess, os.path.join(train_output_dir, "model"), global_step=step) # 测试 def test(): # 设置随机数种子的值 global seed if seed is None: seed = random.randint(0, 2 ** 31 - 1) tf.set_random_seed(seed) np.random.seed(seed) random.seed(seed) # 创建目录 if not os.path.exists(test_output_dir): os.makedirs(test_output_dir) if checkpoint is None: raise Exception("checkpoint required for test mode") # disable these features in test mode scale_size = CROP_SIZE flip = False # 加载数据集,得到输入数据和目标数据 examples = load_examples(test_input_dir) print("load successful ! examples count = %d" % examples.count) # 创建模型,inputs和targets是:[batch_size, height, width, channels] model = create_model(examples.inputs, examples.targets) print("create model successful!") # 图像处理[-1, 1] => [0, 1] inputs = deprocess(examples.inputs) targets = deprocess(examples.targets) outputs = deprocess(model.outputs) # 把[0,1]的像素点转为RGB值:[0,255] with tf.name_scope("convert_inputs"): converted_inputs = convert(inputs) with tf.name_scope("convert_targets"): converted_targets = convert(targets) with tf.name_scope("convert_outputs"): converted_outputs = convert(outputs) # 对图像进行编码以便于保存 with tf.name_scope("encode_images"): display_fetches = { "paths": examples.paths, # tf.map_fn接受一个函数对象和集合,用函数对集合中每个元素分别处理 "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"), "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"), "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"), } sess = tf.InteractiveSession() saver = tf.train.Saver(max_to_keep=1) ckpt = tf.train.get_checkpoint_state(checkpoint) saver.restore(sess,ckpt.model_checkpoint_path) start = time.time() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for step in range(examples.count): results = sess.run(display_fetches) filesets = save_images(test_output_dir, results) for i, f in enumerate(filesets): print("evaluated image", f["name"]) index_path = append_index(test_output_dir, filesets) print("wrote index at", index_path) print("rate", (time.time() - start) / max_steps) if __name__ == '__main__': train() #test()
tensorflow实现BP算法遇到了问题,求大神指点!!!
import tensorflow as tf import numpy as np #from tensorflow.examples.tutorials.mnist import input_data #载入数据集 #mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次的大小 #batch_size = 100 #????????????????????????????????? #计算一共有多少个批次 #n_batch = mnist.train.num_examples // batch_size #定义placeholder x_data=np.mat([[0.4984,0.5102,0.5213,0.5340], [0.5102,0.5213,0.5340,0.5407], [0.5213,0.5340,0.5407,0.5428], [0.5340,0.5407,0.5428,0.5530], [0.5407,0.5428,0.5530,0.5632], [0.5428,0.5530,0.5632,0.5739], [0.5530,0.5632,0.5739,0.5821], [0.5632,0.5739,0.5821,0.5920], [0.5739,0.5821,0.5920,0.5987], [0.5821,0.5920,0.5987,0.6043], [0.5920,0.5987,0.6043,0.6095], [0.5987,0.6043,0.6095,0.6161], [0.6043,0.6095,0.6161,0.6251], [0.6095,0.6161,0.6251,0.6318], [0.6161,0.6251,0.6318,0.6387], [0.6251,0.6318,0.6387,0.6462], [0.6318,0.6387,0.6462,0.6518], [0.6387,0.6462,0.6518,0.6589], [0.6462,0.6518,0.6589,0.6674], [0.6518,0.6589,0.6674,0.6786], [0.6589,0.6674,0.6786,0.6892], [0.6674,0.6786,0.6892,0.6988]]) y_data=np.mat([[0.5407], [0.5428], [0.5530], [0.5632], [0.5739], [0.5821], [0.5920], [0.5987], [0.6043], [0.6095], [0.6161], [0.6251], [0.6318], [0.6387], [0.6462], [0.6518], [0.6589], [0.6674], [0.6786], [0.6892], [0.6988], [0.7072]]) xs = tf.placeholder(tf.float32,[None,4]) # 样本数未知,特征数为1,占位符最后要以字典形式在运行中填入 ys = tf.placeholder(tf.float32,[None,1]) #创建一个简单的神经网络 W1 = tf.Variable(tf.truncated_normal([4,10],stddev=0.1)) b1 = tf.Variable(tf.zeros([10])+0.1) L1 = tf.nn.tanh(tf.matmul(x,W1)+b1) W2 = tf.Variable(tf.truncated_normal([10,1],stddev=0.1)) b2 = tf.Variable(tf.zeros([1])+0.1) L2 = tf.nn.softmax(tf.matmul(L1,W2)+b2) #二次代价函数 #loss = tf.reduce_mean(tf.square(y-prediction)) #loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=ys,logits=prediction)) loss = tf.reduce_mean(tf.reduce_sum(tf.square((ys-L2)),reduction_indices = [1]))#需要向相加索引号,redeuc执行跨纬度操作 #使用梯度下降法 #train_step = tf.train.GradientDescentOptimizer(0.1).mnimize(loss) train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) #train = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 选择梯度下降法 #初始化变量 #init = tf.global_variables_initializer() init = tf.initialize_all_variables() #结果存放在一个布尔型列表中 #correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1)) #求准确率 #accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess: sess.run(init) for epoch in range(21): for i in range(22): #batch_xs,batch_ys = mnist.train.next_batch(batch_size) #????????????????????????? sess.run(train_step,feed_dict={xs:x_data,ys:y_data}) #test_acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0}) #train_acc = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels,keep_prob:1.0}) print (sess.run(prediction,feed_dict={xs:x_data,ys:y_data})) 提示:WARNING:tensorflow:From <ipython-input-10-578836c021a3>:89 in <module>.: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02. Instructions for updating: Use `tf.global_variables_initializer` instead. --------------------------------------------------------------------------- InvalidArgumentError Traceback (most recent call last) C:\Users\Administrator\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args) 1020 try: -> 1021 return fn(*args) 1022 except errors.OpError as e: C:\Users\Administrator\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata) 1002 feed_dict, fetch_list, target_list, -> 1003 status, run_metadata) 1004 。。。
萌新求助代码问题,求各位老司机指点迷津,感激不尽!!
题目要求如下: 尝试编写一个用户登陆程序(这次尝试将该功能封装成函数),程序实现如图: ![图片说明](https://img-ask.csdn.net/upload/201701/08/1483886164_884313.png) 个人代码如下: ``` data = {} def denglu(): print('''\n |--- 新建用户:N/n ---| |--- 登陆账号:E/e ---| |--- 退出程序:Q/q ---|''') order = input('|--- 请输入指令代码:') if order == 'n' or order == 'N': name = input('请输入用户名:') flag1 = 1 #flag1用以判断输入的用户名是否存在,存在为1,不存在为0 while flag1: if name in data: name = input('此用户名已经被使用,请重新输入:') else: flag1 = 0 password = input('请输入密码:') data = [name, password] print('注册成功,赶紧试试登陆吧^_^') return denglu() elif order == 'e' or order == 'E': name = input('请输入用户名:') flag2 = 1 #flag2用以判断输入的用户名是否存在,不存在为1,存在为0 while flag2: if name not in data: name = input('您输入的用户名不存在,请重新输入:') else: flag2 = 0 password = input('请输入密码:') flag3 = 1 while flag3: #flag3用以判断输入的密码是否存在,不存在为1,存在为0 if password != data[name]: password = input('密码错误,请重新输入:') else: flag3 = 0 print('欢迎进入xxoo系统,请点右上角的x结束程序') else: print('欢迎使用XXX登陆程序!') denglu() ``` 运行程序,当输入指令为n(N)或者e(E),输入用户名之后,报错,提示为: UnboundLocalError: local variable 'data' referenced before assignment 但是data字典是我在函数外面定义的全局变量,所以不太理解哪里有问题,希望有经验的小伙伴指出问题所在,感激不尽!!!!!!
终于明白阿里百度这样的大公司,为什么面试经常拿ThreadLocal考验求职者了
点击上面↑「爱开发」关注我们每晚10点,捕获技术思考和创业资源洞察什么是ThreadLocalThreadLocal是一个本地线程副本变量工具类,各个线程都拥有一份线程私有的数
程序员必须掌握的核心算法有哪些?
由于我之前一直强调数据结构以及算法学习的重要性,所以就有一些读者经常问我,数据结构与算法应该要学习到哪个程度呢?,说实话,这个问题我不知道要怎么回答你,主要取决于你想学习到哪些程度,不过针对这个问题,我稍微总结一下我学过的算法知识点,以及我觉得值得学习的算法。这些算法与数据结构的学习大多数是零散的,并没有一本把他们全部覆盖的书籍。下面是我觉得值得学习的一些算法以及数据结构,当然,我也会整理一些看过
Linux(服务器编程):15---两种高效的事件处理模式(reactor模式、proactor模式)
前言 同步I/O模型通常用于实现Reactor模式 异步I/O模型则用于实现Proactor模式 最后我们会使用同步I/O方式模拟出Proactor模式 一、Reactor模式 Reactor模式特点 它要求主线程(I/O处理单元)只负责监听文件描述符上是否有事件发生,有的话就立即将时间通知工作线程(逻辑单元)。除此之外,主线程不做任何其他实质性的工作 读写数据,接受新的连接,以及处...
阿里面试官问我:如何设计秒杀系统?我的回答让他比起大拇指
你知道的越多,你不知道的越多 点赞再看,养成习惯 GitHub上已经开源 https://github.com/JavaFamily 有一线大厂面试点脑图和个人联系方式,欢迎Star和指教 前言 Redis在互联网技术存储方面使用如此广泛,几乎所有的后端技术面试官都要在Redis的使用和原理方面对小伙伴们进行360°的刁难。 作为一个在互联网公司面一次拿一次Offer的面霸,打败了...
五年程序员记流水账式的自白。
不知觉已中码龄已突破五年,一路走来从起初铁憨憨到现在的十九线程序员,一路成长,虽然不能成为高工,但是也能挡下一面,从15年很火的android开始入坑,走过java、.Net、QT,目前仍处于android和.net交替开发中。 毕业到现在一共就职过两家公司,目前是第二家,公司算是半个创业公司,所以基本上都会身兼多职。比如不光要写代码,还要写软著、软著评测、线上线下客户对接需求收集...
C语言魔塔游戏
很早就很想写这个,今天终于写完了。 游戏截图: 编译环境: VS2017 游戏需要一些图片,如果有想要的或者对游戏有什么看法的可以加我的QQ 2985486630 讨论,如果暂时没有回应,可以在博客下方留言,到时候我会看到。 下面我来介绍一下游戏的主要功能和实现方式 首先是玩家的定义,使用结构体,这个名字是可以自己改变的 struct gamerole { char n
一文详尽系列之模型评估指标
点击上方“Datawhale”,选择“星标”公众号第一时间获取价值内容在机器学习领域通常会根据实际的业务场景拟定相应的不同的业务指标,针对不同机器学习问题如回归、分类、排...
究竟你适不适合买Mac?
我清晰的记得,刚买的macbook pro回到家,开机后第一件事情,就是上了淘宝网,花了500元钱,找了一个上门维修电脑的师傅,上门给我装了一个windows系统。。。。。。 表砍我。。。 当时买mac的初衷,只是想要个固态硬盘的笔记本,用来运行一些复杂的扑克软件。而看了当时所有的SSD笔记本后,最终决定,还是买个好(xiong)看(da)的。 已经有好几个朋友问我mba怎么样了,所以今天尽量客观
程序员一般通过什么途径接私活?
二哥,你好,我想知道一般程序猿都如何接私活,我也想接,能告诉我一些方法吗? 上面是一个读者“烦不烦”问我的一个问题。其实不止是“烦不烦”,还有很多读者问过我类似这样的问题。 我接的私活不算多,挣到的钱也没有多少,加起来不到 20W。说实话,这个数目说出来我是有点心虚的,毕竟太少了,大家轻喷。但我想,恰好配得上“一般程序员”这个称号啊。毕竟苍蝇再小也是肉,我也算是有经验的人了。 唾弃接私活、做外...
压测学习总结(1)——高并发性能指标:QPS、TPS、RT、吞吐量详解
一、QPS,每秒查询 QPS:Queries Per Second意思是“每秒查询率”,是一台服务器每秒能够相应的查询次数,是对一个特定的查询服务器在规定时间内所处理流量多少的衡量标准。互联网中,作为域名系统服务器的机器的性能经常用每秒查询率来衡量。 二、TPS,每秒事务 TPS:是TransactionsPerSecond的缩写,也就是事务数/秒。它是软件测试结果的测量单位。一个事务是指一...
Python爬虫爬取淘宝,京东商品信息
小编是一个理科生,不善长说一些废话。简单介绍下原理然后直接上代码。 使用的工具(Python+pycharm2019.3+selenium+xpath+chromedriver)其中要使用pycharm也可以私聊我selenium是一个框架可以通过pip下载 pip install selenium -i https://pypi.tuna.tsinghua.edu.cn/simple/ 
阿里程序员写了一个新手都写不出的低级bug,被骂惨了。
这种新手都不会范的错,居然被一个工作好几年的小伙子写出来,差点被当场开除了。
Java工作4年来应聘要16K最后没要,细节如下。。。
前奏: 今天2B哥和大家分享一位前几天面试的一位应聘者,工作4年26岁,统招本科。 以下就是他的简历和面试情况。 基本情况: 专业技能: 1、&nbsp;熟悉Sping了解SpringMVC、SpringBoot、Mybatis等框架、了解SpringCloud微服务 2、&nbsp;熟悉常用项目管理工具:SVN、GIT、MAVEN、Jenkins 3、&nbsp;熟悉Nginx、tomca
2020年,冯唐49岁:我给20、30岁IT职场年轻人的建议
点击“技术领导力”关注∆  每天早上8:30推送 作者| Mr.K   编辑| Emma 来源| 技术领导力(ID:jishulingdaoli) 前天的推文《冯唐:职场人35岁以后,方法论比经验重要》,收到了不少读者的反馈,觉得挺受启发。其实,冯唐写了不少关于职场方面的文章,都挺不错的。可惜大家只记住了“春风十里不如你”、“如何避免成为油腻腻的中年人”等不那么正经的文章。 本文整理了冯
程序员该看的几部电影
##1、骇客帝国(1999) 概念:在线/离线,递归,循环,矩阵等 剧情简介: 不久的将来,网络黑客尼奥对这个看似正常的现实世界产生了怀疑。 他结识了黑客崔妮蒂,并见到了黑客组织的首领墨菲斯。 墨菲斯告诉他,现实世界其实是由一个名叫“母体”的计算机人工智能系统控制,人们就像他们饲养的动物,没有自由和思想,而尼奥就是能够拯救人类的救世主。 可是,救赎之路从来都不会一帆风顺,到底哪里才是真实的世界?
Python绘图,圣诞树,花,爱心 | Turtle篇
每周每日,分享Python实战代码,入门资料,进阶资料,基础语法,爬虫,数据分析,web网站,机器学习,深度学习等等。 公众号回复【进群】沟通交流吧,QQ扫码进群学习吧 微信群 QQ群 1.画圣诞树 import turtle screen = turtle.Screen() screen.setup(800,600) circle = turtle.Turtle()...
作为一个程序员,CPU的这些硬核知识你必须会!
CPU对每个程序员来说,是个既熟悉又陌生的东西? 如果你只知道CPU是中央处理器的话,那可能对你并没有什么用,那么作为程序员的我们,必须要搞懂的就是CPU这家伙是如何运行的,尤其要搞懂它里面的寄存器是怎么一回事,因为这将让你从底层明白程序的运行机制。 随我一起,来好好认识下CPU这货吧 把CPU掰开来看 对于CPU来说,我们首先就要搞明白它是怎么回事,也就是它的内部构造,当然,CPU那么牛的一个东
还记得那个提速8倍的IDEA插件吗?VS Code版本也发布啦!!
去年,阿里云发布了本地 IDE 插件 Cloud Toolkit,仅 IntelliJ IDEA 一个平台,就有 15 万以上的开发者进行了下载,体验了一键部署带来的开发便利。时隔一年的今天,阿里云正式发布了 Visual Studio Code 版本,全面覆盖前端开发者,帮助前端实现一键打包部署,让开发提速 8 倍。 VSCode 版本的插件,目前能做到什么? 安装插件之后,开发者可以立即体验...
破14亿,Python分析我国存在哪些人口危机!
2020年1月17日,国家统计局发布了2019年国民经济报告,报告中指出我国人口突破14亿。 猪哥的朋友圈被14亿人口刷屏,但是很多人并没有看到我国复杂的人口问题:老龄化、男女比例失衡、生育率下降、人口红利下降等。 今天我们就来分析一下我们国家的人口数据吧! 一、背景 1.人口突破14亿 2020年1月17日,国家统计局发布了 2019年国民经济报告 ,报告中指出:年末中国大陆总人口(包括31个
2019年除夕夜的有感而发
天气:小雨(加小雪) 温度:3摄氏度 空气:严重污染(399) 风向:北风 风力:微风 现在是除夕夜晚上十点钟,再有两个小时就要新的一年了; 首先要说的是我没患病,至少现在是没有患病;但是心情确像患了病一样沉重; 现在这个时刻应该大部分家庭都在看春晚吧,或许一家人团团圆圆的坐在一起,或许因为某些特殊原因而不能团圆;但不管是身在何处,身处什么境地,我都想对每一个人说一句:新年快乐! 不知道csdn这...
听说想当黑客的都玩过这个Monyer游戏(1~14攻略)
第零关 进入传送门开始第0关(游戏链接) 请点击链接进入第1关: 连接在左边→ ←连接在右边 看不到啊。。。。(只能看到一堆大佬做完的留名,也能看到菜鸡的我,在后面~~) 直接fn+f12吧 &lt;span&gt;连接在左边→&lt;/span&gt; &lt;a href="first.php"&gt;&lt;/a&gt; &lt;span&gt;←连接在右边&lt;/span&gt; o...
在家远程办公效率低?那你一定要收好这个「在家办公」神器!
相信大家都已经收到国务院延长春节假期的消息,接下来,在家远程办公可能将会持续一段时间。 但是问题来了。远程办公不是人在电脑前就当坐班了,相反,对于沟通效率,文件协作,以及信息安全都有着极高的要求。有着非常多的挑战,比如: 1在异地互相不见面的会议上,如何提高沟通效率? 2文件之间的来往反馈如何做到及时性?如何保证信息安全? 3如何规划安排每天工作,以及如何进行成果验收? ......
作为一个程序员,内存和磁盘的这些事情,你不得不知道啊!!!
截止目前,我已经分享了如下几篇文章: 一个程序在计算机中是如何运行的?超级干货!!! 作为一个程序员,CPU的这些硬核知识你必须会! 作为一个程序员,内存的这些硬核知识你必须懂! 这些知识可以说是我们之前都不太重视的基础知识,可能大家在上大学的时候都学习过了,但是嘞,当时由于老师讲解的没那么有趣,又加上这些知识本身就比较枯燥,所以嘞,大家当初几乎等于没学。 再说啦,学习这些,也看不出来有什么用啊!
2020年的1月,我辞掉了我的第一份工作
其实,这篇文章,我应该早点写的,毕竟现在已经2月份了。不过一些其它原因,或者是我的惰性、还有一些迷茫的念头,让自己迟迟没有试着写一点东西,记录下,或者说是总结下自己前3年的工作上的经历、学习的过程。 我自己知道的,在写自己的博客方面,我的文笔很一般,非技术类的文章不想去写;另外我又是一个还比较热衷于技术的人,而平常复杂一点的东西,如果想写文章写的清楚点,是需要足够...
别低估自己的直觉,也别高估自己的智商
所有群全部吵翻天,朋友圈全部沦陷,公众号疯狂转发。这两周没怎么发原创,只发新闻,可能有人注意到了。我不是懒,是文章写了却没发,因为大家的关注力始终在这次的疫情上面,发了也没人看。当然,我...
这个世界上人真的分三六九等,你信吗?
偶然间,在知乎上看到一个问题 一时间,勾起了我深深的回忆。 以前在厂里打过两次工,做过家教,干过辅导班,做过中介。零下几度的晚上,贴过广告,满脸、满手地长冻疮。 再回首那段岁月,虽然苦,但让我学会了坚持和忍耐。让我明白了,在这个世界上,无论环境多么的恶劣,只要心存希望,星星之火,亦可燎原。 下文是原回答,希望能对你能有所启发。 如果我说,这个世界上人真的分三六九等,...
节后首个工作日,企业们集体开晨会让钉钉挂了
By 超神经场景描述:昨天 2 月 3 日,是大部分城市号召远程工作的第一天,全国有接近 2 亿人在家开始远程办公,钉钉上也有超过 1000 万家企业活跃起来。关键词:十一出行 人脸...
Java基础知识点梳理
Java基础知识点梳理 摘要: 虽然已经在实际工作中经常与java打交道,但是一直没系统地对java这门语言进行梳理和总结,掌握的知识也比较零散。恰好利用这段时间重新认识下java,并对一些常见的语法和知识点做个总结与回顾,一方面为了加深印象,方便后面查阅,一方面为了学好java打下基础。 Java简介 java语言于1995年正式推出,最开始被命名为Oak语言,由James Gosling(詹姆
2020年全新Java学习路线图,含配套视频,学完即为中级Java程序员!!
新的一年来临,突如其来的疫情打破了平静的生活! 在家的你是否很无聊,如果无聊就来学习吧! 世上只有一种投资只赚不赔,那就是学习!!! 传智播客于2020年升级了Java学习线路图,硬核升级,免费放送! 学完你就是中级程序员,能更快一步找到工作! 一、Java基础 JavaSE基础是Java中级程序员的起点,是帮助你从小白到懂得编程的必经之路。 在Java基础板块中有6个子模块的学
B 站上有哪些很好的学习资源?
哇说起B站,在小九眼里就是宝藏般的存在,放年假宅在家时一天刷6、7个小时不在话下,更别提今年的跨年晚会,我简直是跪着看完的!! 最早大家聚在在B站是为了追番,再后来我在上面刷欧美新歌和漂亮小姐姐的舞蹈视频,最近两年我和周围的朋友们已经把B站当作学习教室了,而且学习成本还免费,真是个励志的好平台ヽ(.◕ฺˇд ˇ◕ฺ;)ノ 下面我们就来盘点一下B站上优质的学习资源: 综合类 Oeasy: 综合
如何优雅地打印一个Java对象?
你好呀,我是沉默王二,一个和黄家驹一样身高,和刘德华一样颜值的程序员。虽然已经写了十多年的 Java 代码,但仍然觉得自己是个菜鸟(请允许我惭愧一下)。 在一个月黑风高的夜晚,我思前想后,觉得再也不能这么蹉跎下去了。于是痛下决心,准备通过输出的方式倒逼输入,以此来修炼自己的内功,从而进阶成为一名真正意义上的大神。与此同时,希望这些文章能够帮助到更多的读者,让大家在学习的路上不再寂寞、空虚和冷。 ...
雷火神山直播超两亿,Web播放器事件监听是怎么实现的?
Web播放器解决了在手机浏览器和PC浏览器上播放音视频数据的问题,让视音频内容可以不依赖用户安装App,就能进行播放以及在社交平台进行传播。在视频业务大数据平台中,播放数据的统计分析非常重要,所以Web播放器在使用过程中,需要对其内部的数据进行收集并上报至服务端,此时,就需要对发生在其内部的一些播放行为进行事件监听。 那么Web播放器事件监听是怎么实现的呢? 01 监听事件明细表 名...
【CSDN学院出品】 你不可不知的JS面试题(分期更新……)
1、JS中有哪些内置类型? 7种。分别是boolean、number、string、object、undefined、null、symbol。 2、NaN是独立的一种类型吗? 不是。NaN是number类型。 3、如何判断是哪个类型? Object.prototype.toString.call(),返回为[object Type]。 现在我们来验证一下。 Object.prototype.toS...
3万字总结,Mysql优化之精髓
本文知识点较多,篇幅较长,请耐心学习 MySQL已经成为时下关系型数据库产品的中坚力量,备受互联网大厂的青睐,出门面试想进BAT,想拿高工资,不会点MySQL优化知识,拿offer的成功率会大大下降。 为什么要优化 系统的吞吐量瓶颈往往出现在数据库的访问速度上 随着应用程序的运行,数据库的中的数据会越来越多,处理时间会相应变慢 数据是存放在磁盘上的,读写速度无法和内存相比 如何优化 设计...
HTML5适合的情人节礼物有纪念日期功能
前言 利用HTML5,css,js实现爱心树 以及 纪念日期的功能 网页有播放音乐功能 以及打字倾诉感情的画面,非常适合情人节送给女朋友 具体的HTML代码 具体只要修改代码里面的男某某和女某某 文字段也可自行修改,还有代码下半部分的JS代码需要修改一下起始日期 注意月份为0~11月 也就是月份需要减一。 当然只有一部分HTML和JS代码不够运行的,文章最下面还附加了完整代码的下载地址 &lt;!...
Git笔记(3) 安装配置
Git的安装,基础配置以及如何获取帮助
Python新型冠状病毒疫情数据自动爬取+统计+发送报告+数据屏幕(三)发送篇
今天介绍的项目是使用 Itchat 发送统计报告 项目功能设计: 定时爬取疫情数据存入Mysql 进行数据分析制作疫情报告 使用itchat给亲人朋友发送分析报告 基于Django做数据屏幕 使用Tableau做数据分析 来看看最终效果 目前已经完成,预计2月12日前更新 使用 itchat 发送数据统计报告 itchat 是一个基于 web微信的一个框架,但微信官方并不允许使用这...
相关热词 c# 压缩图片好麻烦 c#计算数组中的平均值 c#获取路由参数 c#日期精确到分钟 c#自定义异常必须继承 c#查表并返回值 c# 动态 表达式树 c# 监控方法耗时 c# listbox c#chart显示滚动条
立即提问