tensorflow简单介绍
tensor采用图运算的方式搭建并训练深度学习网络,该部分使用的库包版本为tensorflow==1.14.0
代码拆分(每个代码块可以放到一个jupyter的cell里)
导入tensorflow及其他包
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
载入数据集
# 载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
定义网络结构及进行训练
# 每个批次的大小
batch_size = 100
# 计算一共有多少个批次
n_batch = mnist.train.num_examples //batch_size
# 定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
# 创建一个简单的神经网络
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W) + b)
# 二次损失函数
loss = tf.reduce_mean(tf.square(y-prediction))
# 使用梯度下降法
train_step