TensorFlow 训练时报错:axis 1 索引越界(index out of bounds for axis 1)

# -*- coding: utf-8 -*-

#author JLU_GuanQQ
"""
Spyder Editor

This is a temporary script file.
"""

#step0 import module and generate dataset
import tensorflow as tf
import numpy as np
#import matplotlib.pyplot as plt
# Step 0: build a small, reproducible toy dataset.
BATCH_SIZE = 8  # rows per mini-batch fed to each training step
seed = 2
# A dedicated, seeded generator makes every run see identical data.
rdnum = np.random.RandomState(seed)
# 32 samples x 2 features, each drawn uniformly from [0, 1).
x_true = rdnum.rand(32, 2)
# Label a sample 1 when its two features sum to less than 1, otherwise 0.
# Each label is wrapped in its own one-element list (one row per sample).
y_true = [[int(sample[0] + sample[1] < 1)] for sample in x_true]

# Step 1: declare the graph inputs and the forward pass of a 2-3-1 network.
x = tf.placeholder(dtype=tf.float32, shape=(None, 2))  # features; any batch size
y = tf.placeholder(dtype=tf.float32, shape=(None, 1))  # target labels
w1 = tf.Variable(tf.random_normal(shape=[2, 3]))  # input -> hidden weights
w2 = tf.Variable(tf.random_normal(shape=[3, 1]))  # hidden -> output weights
# No activation is applied between the two matmuls, so the whole
# network is a linear map from x to y_.
a = tf.matmul(x, w1)  # hidden layer, shape (batch, 3)
y_ = tf.matmul(a, w2)  # prediction, shape (batch, 1)

# Step 2: mean-squared-error loss, plus the op that performs one
# gradient-descent update of the trainable weights when run.
loss_function = tf.reduce_mean(
    tf.square(y_ - y)
)
train_step = tf.train.GradientDescentOptimizer(
    learning_rate=0.001
).minimize(loss_function)

# Step 3: open a session and run the training loop.
with tf.Session() as sess:
    init_parameter = tf.global_variables_initializer()  # op that initializes w1 and w2
    sess.run(init_parameter)
    steps = 20000  # number of training iterations
    # Derive the dataset size from the data instead of hard-coding 32,
    # so the batch cycling below stays correct if the dataset changes.
    num_samples = len(x_true)
    for i in range(steps):
        # Walk through the dataset in consecutive mini-batches, wrapping
        # around to the beginning once the end is reached.
        start_position = (i * BATCH_SIZE) % num_samples
        end_position = start_position + BATCH_SIZE
        # Use slice syntax (start:end), NOT a tuple index (start, end).
        # x_true[start, end] selects row `start`, column `end` of the (32, 2)
        # array, which raised "index 8 is out of bounds for axis 1 with
        # size 2". On the list y_true a tuple index is a TypeError.
        # If num_samples is not a multiple of BATCH_SIZE, the final slice
        # is simply shorter; slicing never raises for an out-of-range end.
        sess.run(
            train_step,
            feed_dict={
                x: x_true[start_position:end_position],
                y: y_true[start_position:end_position],
            },
        )

最后一句(sess.run 那一行)报错:IndexError: index 8 is out of bounds for axis 1 with size 2
Csdn user default icon
上传中...
上传图片
插入图片
抄袭、复制答案,以达到刷声望分或其他目的的行为,在CSDN问答是严格禁止的,一经发现立刻封号。是时候展现真正的技术了!
立即提问