# Python 2/3 compatibility switches; these must be the first statements in
# the module.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse  # command-line parsing for --data_dir
import sys

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# Parsed command-line flags; assigned in the __main__ block before main() runs.
FLAGS = None
def main(_):
    """Train and evaluate a softmax-regression classifier on MNIST.

    Loads MNIST from ``FLAGS.data_dir`` (downloading any missing file into
    that directory), trains the linear model ``y = x W + b`` with gradient
    descent (learning rate 0.5) for 1000 steps of 100-example mini-batches,
    then prints the accuracy on the test set.

    Args:
      _: Unused; ``tf.app.run`` passes the remaining argv here.

    Raises:
      IOError: re-raised from the data loader when the MNIST files are not
        in ``FLAGS.data_dir`` and cannot be fetched (e.g. a network timeout
        surfaces as ``URLError``, a subclass of IOError/OSError).
    """
    # Import data. The loader only downloads a file that is missing from
    # data_dir, so on a machine that cannot reach the download server the
    # MNIST .gz files can be placed in data_dir by hand instead.
    try:
        mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    except IOError:
        # Add an actionable hint, then re-raise the original error unchanged.
        print('Could not load MNIST into %r. If the download failed (e.g. a '
              'connection timeout), fetch train-images-idx3-ubyte.gz, '
              'train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and '
              't10k-labels-idx1-ubyte.gz manually and put them in that '
              'directory.' % FLAGS.data_dir, file=sys.stderr)
        raise

    # Create the model.
    # A placeholder is a symbolic input used to define the computation; its
    # concrete value is supplied at run time through feed_dict. Shape
    # [None, 784] accepts any batch size of flattened 28x28 images.
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))  # weights, all entries start at 0
    b = tf.Variable(tf.zeros([10]))       # biases, all entries start at 0
    # Unnormalised logits for the 10 classes (softmax is applied in the loss).
    y = tf.matmul(x, W) + b

    # Define loss and optimizer.
    y_ = tf.placeholder(tf.float32, [None, 10])  # one-hot true labels
    # The raw formulation of cross-entropy,
    #
    #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
    #                                 reduction_indices=[1]))
    #
    # can be numerically unstable.
    #
    # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
    # outputs of 'y', and then average across the batch.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

    # Evaluation: fraction of examples whose highest logit matches the label.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Using the session as a context manager closes it on exit (releasing its
    # resources) even if training raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Train
        for _ in range(1000):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

        # Test trained model
        print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                            y_: mnist.test.labels}))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str,
                        default='/tmp/tensorflow/mnist/input_data',
                        help='Directory for storing input data')
    # parse_known_args (not parse_args) tolerates options this script does not
    # define -- e.g. extra arguments a notebook kernel may pass -- and the
    # leftovers are forwarded to tf.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```
下面是报错 (below is the error output):
TimeoutError Traceback (most recent call last)
~\Anaconda3\envs\tensorflow\lib\urllib\request.py in do_open(self, http_class, req, **http_conn_args)
1317 h.request(req.get_method(), req.selector, req.data, headers,
-> 1318 encode_chunked=req.has_header('Transfer-encoding'))
1319 except OSError as err: # timeout error
~\Anaconda3\envs\tensorflow\lib\http\client.py in request(self, method, url, body, headers, encode_chunked)
1238 """Send a complete request to the server."""
-> 1239 self._send_request(method, url, body, headers, encode_chunked)
1240
~\Anaconda3\envs\tensorflow\lib\http\client.py in _send_request(self, method, url, body, headers, encode_chunked)
1284 body = _encode(body, 'body')
-> 1285 self.endheaders(body, encode_chunked=encode_chunked)
1286
~\Anaconda3\envs\tensorflow\lib\http\client.py in endheaders(self, message_body, encode_chunked)
1233 raise CannotSendHeader()
-> 1234 self._send_output(message_body, encode_chunked=encode_chunked)
1235
~\Anaconda3\envs\tensorflow\lib\http\client.py in _send_output(self, message_body, encode_chunked)
1025 del self._buffer[:]
-> 1026 self.send(msg)
1027
~\Anaconda3\envs\tensorflow\lib\http\client.py in send(self, data)
963 if self.auto_open:
--> 964 self.connect()
965 else:
~\Anaconda3\envs\tensorflow\lib\http\client.py in connect(self)
1399 self.sock = self._context.wrap_socket(self.sock,
-> 1400 server_hostname=server_hostname)
1401 if not self._context.check_hostname and self._check_hostname:
~\Anaconda3\envs\tensorflow\lib\ssl.py in wrap_socket(self, sock, server_side, do_handshake_on_connect, suppress_ragged_eofs, server_hostname, session)
400 server_hostname=server_hostname,
--> 401 _context=self, _session=session)
402
~\Anaconda3\envs\tensorflow\lib\ssl.py in __init__(self, sock, keyfile, certfile, server_side, cert_reqs, ssl_version, ca_certs, do_handshake_on_connect, family, type, proto, fileno, suppress_ragged_eofs, npn_protocols, ciphers, server_hostname, _context, _session)
807 raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
--> 808 self.do_handshake()
809
~\Anaconda3\envs\tensorflow\lib\ssl.py in do_handshake(self, block)
1060 self.settimeout(None)
-> 1061 self._sslobj.do_handshake()
1062 finally:
~\Anaconda3\envs\tensorflow\lib\ssl.py in do_handshake(self)
682 """Start the SSL/TLS handshake."""
--> 683 self._sslobj.do_handshake()
684 if self.context.check_hostname:
TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败。
During handling of the above exception, another exception occurred:
URLError Traceback (most recent call last)
<ipython-input-1-eaf9732201f9> in <module>()
57 help='Directory for storing input data')
58 FLAGS, unparsed = parser.parse_known_args()
---> 59 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py in run(main, argv)
46 # Call the main function, passing through any arguments
47 # to the final program.
---> 48 _sys.exit(main(_sys.argv[:1] + flags_passthrough))
49
50
<ipython-input-1-eaf9732201f9> in main(_)
15 def main(_):
16 # Import data
---> 17 mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
18
19 # Create the model
~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py in read_data_sets(train_dir, fake_data, one_hot, dtype, reshape, validation_size, seed)
238
239 local_file = base.maybe_download(TRAIN_LABELS, train_dir,
--> 240 SOURCE_URL + TRAIN_LABELS)
241 with open(local_file, 'rb') as f:
242 train_labels = extract_labels(f, one_hot=one_hot)
~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py in maybe_download(filename, work_directory, source_url)
206 filepath = os.path.join(work_directory, filename)
207 if not gfile.Exists(filepath):
--> 208 temp_file_name, _ = urlretrieve_with_retry(source_url)
209 gfile.Copy(temp_file_name, filepath)
210 with gfile.GFile(filepath) as f:
~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py in wrapped_fn(*args, **kwargs)
163 for delay in delays():
164 try:
--> 165 return fn(*args, **kwargs)
166 except Exception as e: # pylint: disable=broad-except)
167 if is_retriable is None:
~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py in urlretrieve_with_retry(url, filename)
188 @retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable)
189 def urlretrieve_with_retry(url, filename=None):
--> 190 return urllib.request.urlretrieve(url, filename)
191
192
~\Anaconda3\envs\tensorflow\lib\urllib\request.py in urlretrieve(url, filename, reporthook, data)
246 url_type, path = splittype(url)
247
--> 248 with contextlib.closing(urlopen(url, data)) as fp:
249 headers = fp.info()
250
~\Anaconda3\envs\tensorflow\lib\urllib\request.py in urlopen(url, data, timeout, cafile, capath, cadefault, context)
221 else:
222 opener = _opener
--> 223 return opener.open(url, data, timeout)
224
225 def install_opener(opener):
~\Anaconda3\envs\tensorflow\lib\urllib\request.py in open(self, fullurl, data, timeout)
524 req = meth(req)
525
--> 526 response = self._open(req, data)
527
528 # post-process response
~\Anaconda3\envs\tensorflow\lib\urllib\request.py in _open(self, req, data)
542 protocol = req.type
543 result = self._call_chain(self.handle_open, protocol, protocol +
--> 544 '_open', req)
545 if result:
546 return result
~\Anaconda3\envs\tensorflow\lib\urllib\request.py in _call_chain(self, chain, kind, meth_name, *args)
502 for handler in handlers:
503 func = getattr(handler, meth_name)
--> 504 result = func(*args)
505 if result is not None:
506 return result
~\Anaconda3\envs\tensorflow\lib\urllib\request.py in https_open(self, req)
1359 def https_open(self, req):
1360 return self.do_open(http.client.HTTPSConnection, req,
-> 1361 context=self._context, check_hostname=self._check_hostname)
1362
1363 https_request = AbstractHTTPHandler.do_request_
~\Anaconda3\envs\tensorflow\lib\urllib\request.py in do_open(self, http_class, req, **http_conn_args)
1318 encode_chunked=req.has_header('Transfer-encoding'))
1319 except OSError as err: # timeout error
-> 1320 raise URLError(err)
1321 r = h.getresponse()
1322 except:
URLError: <urlopen error [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败。>
In [ ]: