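I have recently been learning TensorFlow. While building a handwritten-digit recognition demo I ran into the tf.nn.conv2d method. The official API documentation is fairly terse and I still did not understand it after reading it, so I searched around, read a few blog posts, and, combined with my own understanding, more or less worked it out.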
Method definition
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
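Parameters:
- input: the images to convolve, given as a tensor of shape [batch, in_height, in_width, in_channels], where batch is the number of images, in_height is the image height, in_width is the image width, and in_channels is the number of channels per image: 1 for grayscale and 3 for color. (Other values are allowed as well, for example when the input is the feature map produced by an earlier convolution layer, in which case in_channels equals that layer's out_channels.)
- filter: the convolution kernel, also a tensor, of shape [filter_height, filter_width, in_channels, out_channels], where filter_height is the kernel height, filter_width is the kernel width, in_channels is the number of image channels and must match the in_channels of input, and out_channels is the number of kernels.
- strides: the step size along each dimension of the input during convolution. It is a one-dimensional vector [1, stride, stride, 1]; with the default data format the first and last entries (the batch and channel strides) must be 1.
- padding: a string, either "SAME" or "VALID", which selects how the borders are handled. "SAME" pads the border with zeros wherever the kernel would run past the edge, so with a stride of 1 the output has the same height and width as the input. "VALID" adds no padding, so the kernel only visits positions where it fits entirely inside the image.
- use_cudnn_on_gpu: a bool that controls whether cuDNN acceleration is used; it is treated as True when left unset.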
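The way strides and padding determine the output shape can be sketched in plain Python. This is only an illustration of the shape rule for the default [batch, height, width, channels] layout; the helper name conv2d_output_shape is made up for this post and is not part of TensorFlow.

```python
import math


def conv2d_output_shape(input_shape, filter_shape, strides, padding):
    """Return the output shape [batch, out_height, out_width, out_channels].

    input_shape:  [batch, in_height, in_width, in_channels]
    filter_shape: [filter_height, filter_width, in_channels, out_channels]
    strides:      [1, stride_h, stride_w, 1]
    padding:      "SAME" or "VALID"
    """
    batch, in_h, in_w, in_c = input_shape
    f_h, f_w, f_in_c, out_c = filter_shape
    if in_c != f_in_c:
        raise ValueError("input and filter disagree on in_channels")
    _, s_h, s_w, _ = strides

    if padding == "SAME":
        # Zero padding makes the output size depend only on the stride.
        out_h = math.ceil(in_h / s_h)
        out_w = math.ceil(in_w / s_w)
    elif padding == "VALID":
        # No padding: only positions where the kernel fits entirely count.
        out_h = math.ceil((in_h - f_h + 1) / s_h)
        out_w = math.ceil((in_w - f_w + 1) / s_w)
    else:
        raise ValueError("padding must be 'SAME' or 'VALID'")
    return [batch, out_h, out_w, out_c]


print(conv2d_output_shape([1, 3, 3, 5], [3, 3, 5, 1], [1, 1, 1, 1], "VALID"))  # [1, 1, 1, 1]
print(conv2d_output_shape([10, 5, 5, 5], [3, 3, 5, 7], [1, 2, 2, 1], "SAME"))  # [10, 3, 3, 7]
```

The two calls use the same shapes as case 3 and case 8 in the script in this post, and the results agree with the shapes stated in the comments there.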
Implementation
import tensorflow as tf
# case 1
# Input: 1 image of size 3x3 with 5 channels; kernel size 1x1, 1 kernel
# With strides [1,1,1,1] the result is one 3x3 feature map
# For the single image the output is a tensor of shape [1,3,3,1]
input = tf.Variable(tf.random_normal([1,3,3,5]))
filter = tf.Variable(tf.random_normal([1,1,5,1]))
op1 = tf.nn.conv2d(input, filter, strides=[1,1,1,1], padding='SAME')
# case 2
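# Input: 1 image of size 3x3 with 5 channels; kernel size 2x2, 1 kernel
# With strides [1,1,1,1] the result is one 3x3 feature map (SAME: zero-padded border)
# For the single image the output is a tensor of shape [1,3,3,1]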
input = tf.Variable(tf.random_normal([1,3,3,5]))
filter = tf.Variable(tf.random_normal([2,2,5,1]))
op2 = tf.nn.conv2d(input, filter, strides=[1,1,1,1], padding='SAME')
# case 3
# Input: 1 image of size 3x3 with 5 channels; kernel size 3x3, 1 kernel
# With strides [1,1,1,1] the result is one 1x1 feature map (VALID: no padding)
# For the single image the output is a tensor of shape [1,1,1,1]
input = tf.Variable(tf.random_normal([1,3,3,5]))
filter = tf.Variable(tf.random_normal([3,3,5,1]))
op3 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
# case 4
# Input: 1 image of size 5x5 with 5 channels; kernel size 3x3, 1 kernel
# With strides [1,1,1,1] the result is one 3x3 feature map (VALID: no padding)
# For the single image the output is a tensor of shape [1,3,3,1]
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,1]))
op4 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
# case 5
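# Input: 1 image of size 5x5 with 5 channels; kernel size 3x3, 1 kernel
# With strides [1,1,1,1] the result is one 5x5 feature map (SAME: zero-padded border)
# For the single image the output is a tensor of shape [1,5,5,1]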
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,1]))
op5 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
# case 6
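# Input: 1 image of size 5x5 with 5 channels; kernel size 3x3, 7 kernels
# With strides [1,1,1,1] the result is 7 feature maps of size 5x5 (SAME: zero-padded border)
# For the single image the output is a tensor of shape [1,5,5,7]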
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,7]))
op6 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
# case 7
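# Input: 1 image of size 5x5 with 5 channels; kernel size 3x3, 7 kernels
# With strides [1,2,2,1] the result is 7 feature maps of size 3x3 (SAME: zero-padded border)
# For the single image the output is a tensor of shape [1,3,3,7]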
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,7]))
op7 = tf.nn.conv2d(input, filter, strides=[1, 2, 2, 1], padding='SAME')
# case 8
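# Input: 10 images of size 5x5 with 5 channels; kernel size 3x3, 7 kernels
# With strides [1,2,2,1] each image yields 7 feature maps of size 3x3 (SAME: zero-padded border)
# For the 10 images the output is a tensor of shape [10,3,3,7]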
input = tf.Variable(tf.random_normal([10,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,7]))
op8 = tf.nn.conv2d(input, filter, strides=[1, 2, 2, 1], padding='SAME')
# tf.initialize_all_variables() is deprecated; this is its replacement
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('*' * 20 + ' op1 ' + '*' * 20)
    print(sess.run(op1))
    print('*' * 20 + ' op2 ' + '*' * 20)
    print(sess.run(op2))
    print('*' * 20 + ' op3 ' + '*' * 20)
    print(sess.run(op3))
    print('*' * 20 + ' op4 ' + '*' * 20)
    print(sess.run(op4))
    print('*' * 20 + ' op5 ' + '*' * 20)
    print(sess.run(op5))
    print('*' * 20 + ' op6 ' + '*' * 20)
    print(sess.run(op6))
    print('*' * 20 + ' op7 ' + '*' * 20)
    print(sess.run(op7))
    print('*' * 20 + ' op8 ' + '*' * 20)
    print(sess.run(op8))
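For the padding='SAME' cases above, it can help to see how many zeros end up on each side of the image. The sketch below follows the padding rule described in the TensorFlow documentation (when the total padding is odd, the extra zero goes at the bottom or right). The helper name same_padding is hypothetical and exists only for this illustration.

```python
def same_padding(in_size, filter_size, stride):
    """Return (pad_before, pad_after): zeros added along one spatial dimension."""
    out_size = -(-in_size // stride)  # integer form of ceil(in_size / stride)
    total = max((out_size - 1) * stride + filter_size - in_size, 0)
    pad_before = total // 2           # top or left
    pad_after = total - pad_before    # bottom or right gets the extra zero
    return pad_before, pad_after


print(same_padding(3, 1, 1))  # case 1: (0, 0)
print(same_padding(3, 2, 1))  # case 2: (0, 1)
print(same_padding(5, 3, 1))  # cases 5 and 6: (1, 1)
print(same_padding(5, 3, 2))  # cases 7 and 8: (1, 1)
```

Case 2 is the interesting one: a 2x2 kernel on a 3x3 image needs only one extra row and one extra column, and both are added at the bottom and right.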
# Output (the inputs are random, so the values differ from run to run)
******************** op1 ********************
[[[[ 0.78366613]
[-0.11703026]
[ 3.533338 ]]
[[ 3.4455981 ]
[-2.40102 ]
[-1.3336506 ]]
[[ 1.9816184 ]
[-3.3166158 ]
[ 2.0968733 ]]]]
******************** op2 ********************
[[[[-4.429776 ]
[ 4.1218996 ]
[-4.1383405 ]]
[[ 0.4804101 ]
[ 1.3983132 ]
[ 1.2663789 ]]
[[-1.8450742 ]
[-0.02915052]
[-0.5696235 ]]]]
******************** op3 ********************
[[[[-6.969367]]]]
******************** op4 ********************
[[[[ -2.9217496 ]
[ 4.4683943 ]
[ 7.5761824 ]]
[[-14.627491 ]
[ -5.014709 ]
[ -3.4593797 ]]
[[ 0.45091882]
[ 4.8827124 ]
[ -9.658895 ]]]]
******************** op5 ********************
[[[[-2.8486536 ]
[ 1.3990458 ]
[ 2.953944 ]
[-6.007198 ]
[ 5.089696 ]]
[[-0.20283715]
[ 2.4726171 ]
[ 6.2137847 ]
[-0.38609552]
[-1.8869443 ]]
[[ 7.7240233 ]
[10.6962805 ]
[-3.1667676 ]
[-3.6487846 ]
[-2.2908094 ]]
[[-9.00223 ]
[ 4.5111785 ]
[ 2.5615098 ]
[-5.8492236 ]
[ 1.7734764 ]]
[[ 2.3674765 ]
[-5.9122458 ]
[ 5.867611 ]
[-0.50353 ]
[-4.890904 ]]]]
******************** op6 ********************
[[[[-4.06957626e+00 5.69651246e-01 2.97890633e-01 -5.08075190e+00
2.76357365e+00 -7.34121323e+00 -2.09436584e+00]
[-9.03515625e+00 -8.96854973e+00 -4.40316677e+00 -3.23745847e+00
-3.56242275e+00 3.67262197e+00 2.59603453e+00]
[ 1.25131302e+01 1.30267200e+01 2.25630283e+00 3.31285048e+00
-1.00396938e+01 -9.06786323e-01 -7.20120049e+00]
[-3.18641067e-01 -7.66135693e+00 5.02029419e+00 -1.65469778e+00
-5.53000355e+00 -4.76842117e+00 4.98133230e+00]
[ 3.68885136e+00 2.54145473e-01 -4.17096436e-01 1.20136106e+00
-2.29291725e+00 6.98313904e+00 4.92819786e-01]]
[[ 1.22962761e+01 3.85902214e+00 -2.91524696e+00 -6.89016438e+00
3.35520816e+00 -1.85112596e+00 5.59113741e+00]
[ 2.99087334e+00 4.42690086e+00 -3.34755349e+00 -7.41521478e-01
3.65099478e+00 -2.84761238e+00 -2.74149513e+00]
[-9.65088654e+00 -4.91817188e+00 3.82093906e+00 -5.72443676e+00
1.43630829e+01 5.11133957e+00 -1.18163595e+01]
[ 1.69606721e+00 -1.00837049e+01 9.65112305e+00 3.48559356e+00
4.71356201e+00 -2.74463081e+00 -5.76961470e+00]
[-5.11555862e+00 1.06215849e+01 1.97274566e+00 -1.66155469e+00
5.40411043e+00 1.64753020e+00 -2.25898552e+00]]
[[ 3.20135975e+00 1.16082029e+01 6.35383892e+00 -1.22541785e+00
-7.81781197e-01 -7.39507914e+00 3.02070093e+00]
[ 3.37887239e+00 -3.17085648e+00 8.15050030e+00 9.17820644e+00
-5.42563820e+00 -1.06148596e+01 1.44039564e+01]
[ 6.06520414e+00 -6.89214110e-01 1.18828654e+00 6.44250536e+00
-3.90648508e+00 -7.45609093e+00 1.70780718e-02]
[-5.51369572e+00 -5.99862814e-01 -5.97459745e+00 5.03705800e-01
-4.89957094e-01 4.65023327e+00 6.97832489e+00]
[ 5.56566572e+00 3.15251064e+00 4.23309374e+00 4.58887959e+00
1.11150384e+00 1.56815052e-01 -2.64446616e+00]]
[[-3.47755957e+00 -2.51347685e+00 5.07092476e+00 -1.79448032e+01
1.23025656e+00 -7.04272604e+00 -3.11969209e+00]
[-3.64519453e+00 -2.48672795e+00 1.45192409e+00 -7.42938709e+00
7.32508659e-01 1.73417020e+00 -8.84127915e-01]
[ 4.80518007e+00 -1.00521259e+01 -1.47410703e+00 -2.73861027e+00
-6.11766815e+00 5.89801645e+00 7.41809845e+00]
[ 1.52897854e+01 3.40052223e+00 -1.17849231e-01 8.11421871e+00
-7.15329647e-02 -8.57025623e+00 -6.36894524e-01]
[-1.29184561e+01 -2.07097292e+00 6.51137114e+00 4.45195580e+00
6.51636696e+00 1.94592953e-01 7.76367307e-01]]
[[-7.64904690e+00 -4.64357853e+00 -5.09730625e+00 1.46977997e+00
-2.66898251e+00 6.18280554e+00 7.30443239e+00]
[ 3.74768376e-02 8.19200230e+00 -2.99126768e+00 -1.25706446e+00
2.82602859e+00 4.79209185e-01 -7.99170971e+00]
[-9.31276321e+00 2.71563363e+00 2.68426132e+00 -2.98767281e+00
2.85978794e-01 5.26730251e+00 -6.51313114e+00]
[-5.16205406e+00 -3.73660684e+00 -1.25655127e+00 -4.03212357e+00
-2.34876966e+00 3.49581933e+00 3.21578264e-01]
[ 4.80592680e+00 -2.01916337e+00 -2.70319057e+00 9.14705086e+00
3.14293051e+00 -5.12257957e+00 1.87513745e+00]]]]
******************** op7 ********************
[[[[ -5.3398733 4.176247 -1.0400615 1.7490227 -2.3762708
-4.43866 -2.9152555 ]
[ -6.2849035 2.9156108 2.2420614 3.0133455 2.697643
-1.2664369 2.2018924 ]
[ -1.7367094 -2.6707978 -4.823809 -2.9799473 -2.588249
-0.8573512 0.7243177 ]]
[[ 9.770168 -6.0919194 -7.755929 0.7116828 4.696847
-1.5403405 -10.603018 ]
[ -2.2849545 7.23973 0.06859291 -0.3011052 -7.885673
-4.7223825 -1.2202084 ]
[ -1.7584102 -0.9349402 1.8078477 6.8720684 11.548839
-1.3058915 1.785974 ]]
[[ 3.8749192 -5.9033284 1.3921509 -2.68101 5.386052
5.2535496 7.804141 ]
[ 1.9598813 -6.1589165 0.9447456 0.06089067 -3.7891803
-2.0653834 -2.60965 ]
[ -2.1243367 -0.9703847 1.5366316 5.8760977 -3.697129
6.050654 -0.01914603]]]]
******************** op8 ********************
[[[[ 7.6126375 -2.261326 0.32292777 8.602917 -2.9009488
3.3160565 2.1506643 ]
[ -3.5364501 -2.1440878 1.354662 5.531647 -1.4339367
5.1957445 -0.9030779 ]
[ 7.844642 -6.1276717 7.7938704 -2.23364 -3.4782376
-5.097751 5.285432 ]]
[[ -1.6915132 2.2787857 -5.9708385 8.21313 -4.5076394
-0.3270775 -8.479343 ]
[ 2.0611243 3.1743298 -0.53598183 -3.0830724 -13.820877
5.3642063 -4.0782714 ]
[ -2.2280676 -6.232974 6.031793 6.4705186 1.1858556
-5.012024 -0.12968755]]
[[ -2.7237153 -2.0637414 1.4018252 -2.937191 2.572178
3.9408593 2.605546 ]
[ -1.607345 5.66703 -4.989913 -6.0507936 -1.9384562
0.61666656 -6.9282484 ]
[ -0.03978544 -2.008681 -7.406146 -1.2036608 -3.8769712
-3.0997906 6.066886 ]]]
[[[ -0.6766513 -0.16299164 3.2324884 -3.3543284 2.711526
-0.7604065 -2.9422672 ]
[-11.477009 6.985447 -7.168281 1.6444209 2.1505005
-2.5210168 1.248457 ]
[ -2.5344536 0.78997815 4.921354 0.32946062 -3.4039345
2.3872323 1.0319829 ]]
[[ 5.672534 -4.6865053 5.780566 11.394991 1.0943577
1.6653306 -0.93034 ]
[ 11.131994 6.8491035 -15.839502 7.006518 3.261397
-0.99962735 10.55006 ]
[ 2.6103654 2.7730281 2.3594556 3.5570846 6.1872926
4.217743 -6.4607897 ]]
[[ -2.7581267 -0.12229636 1.351732 -4.4823456 2.1730578
-2.828763 -3.0473292 ]
[ -2.742803 -5.817521 -4.570032 -7.3254657 3.2537496
-0.6938226 0.6609373 ]
[ -3.1279428 -4.922457 2.745709 -4.864913 -3.6143937
2.6719465 -1.1376699 ]]]
[[[ -0.7445632 0.45240074 5.131389 -2.8525875 1.3901956
-0.4648465 5.4685025 ]
[ 3.1593595 1.2171756 0.1267331 -3.2178001 -2.6123729
-5.186987 4.1898375 ]
[ 9.478796 -1.8722348 4.896418 1.301182 -3.6362329
-1.9956454 -1.770525 ]]
[[ 4.8301635 -3.8837552 7.0490103 1.2435023 3.4047306
-3.2604568 1.051601 ]
[ -2.2003438 0.88552344 -6.8119774 7.017317 -2.9890797
5.8106375 -0.863615 ]
[ -0.17809808 -10.802618 3.225249 -2.0419974 5.072168
1.2349106 -4.600774 ]]
[[ -3.1843624 -2.5729177 1.191327 -3.0042355 0.97465754
-4.564925 3.9409044 ]
[ 1.2322719 14.114404 -0.35690814 2.2237332 0.35432827
-1.9053037 -12.545719 ]
[ 0.80399454 -5.358243 -6.344287 3.5417094 -3.9716966
-0.02347088 3.0606985 ]]]
[[[ 0.37148464 -3.8297706 -2.0831337 6.29245 2.5057077
0.8506646 1.9863653 ]
[ 3.765554 1.4267049 1.0800252 7.7149706 0.44219214
8.109619 3.6685073 ]
[ 4.635173 -2.9154918 -6.4538617 -5.448964 6.57819
0.61271524 2.9938192 ]]
[[ -3.616211 0.0879938 -6.3440037 1.6937144 0.04956067
2.4064069 -8.493458 ]
[ -5.0647597 0.93558145 -1.9845109 -8.771115 4.6100225
1.1144816 -12.28625 ]
[ 1.0221918 -7.5176277 -1.8426392 -4.289383 2.2868915
-8.87014 -0.3772235 ]]
[[ -1.1132717 2.4524128 -0.365159 4.004697 -1.5730555
0.5331385 -6.8898973 ]
[ 3.5391765 2.8012395 0.7159001 7.421248 -3.0292435
3.0187619 -3.9419355 ]
[ -5.387392 -6.63677 2.4566684 1.821631 -0.16935372
-0.88219285 2.2688925 ]]]
[[[ -3.9313369 -1.8516166 -3.2839324 -6.9028835 8.055535
-1.080044 -1.732337 ]
[ -3.1068752 2.6514802 3.7293913 -1.7883471 -5.44104
-4.5572286 -4.829409 ]
[ 2.6451612 -3.1832254 3.171578 -4.6448216 -4.001822
-6.899353 -0.6295476 ]]
[[ -0.65707624 -1.9670736 6.3386445 2.3041923 -4.439172
-2.9729037 -0.94020796]
[ 0.43153757 5.194006 0.45434368 3.0731819 4.0513067
-5.8058457 6.947601 ]
[ -4.2653627 0.9031774 -1.6685407 -5.4121113 0.5529208
0.7007126 9.279081 ]]
[[ -0.37299162 2.7452188 1.9330034 3.6408103 -5.0701776
1.1965587 0.59263295]
[ 4.81972 -1.1006856 7.8824034 5.260598 3.434634
0.04601002 8.869657 ]
[ 4.231048 1.5457909 -4.7653384 -3.4977267 3.7780495
-5.872396 12.113913 ]]]
[[[ -3.8766992 -0.398234 -1.9723368 1.2132525 0.56892383
1.2515173 3.7913866 ]
[ -0.4337333 1.8678297 5.1747704 -0.6080067 -1.3174248
-1.7126535 0.4686459 ]
[ -5.754308 -2.4168007 -3.6410232 -4.5670137 1.6215359
-4.580209 -5.5926514 ]]
[[ 11.04498 4.4554973 3.8934658 -1.4875691 -11.931008
4.515834 -6.144173 ]
[ 3.8855233 -7.6059284 5.552779 -0.4441495 4.6369743
2.3952575 4.981801 ]
[ -4.5357304 8.016967 -3.8956852 8.697634 0.7237491
-1.2161034 9.980692 ]]
[[ -3.8816683 -6.1477547 6.313223 3.8985054 -2.1990623
2.0681944 -0.53726804]
[ 0.9768859 0.2593964 5.1300526 -4.3372006 4.838679
1.2677834 1.0290532 ]
[ -2.7676988 6.0724287 4.556395 -2.004102 -0.79856735
2.4891334 -1.8703268 ]]]
[[[ -2.4113853 -4.7984595 -0.28992027 1.1324785 5.6149826
3.4891384 -0.2521189 ]
[ 11.86079 -2.660718 1.3913785 -9.618228 0.04568058
-2.8031406 1.12844 ]
[ -0.08115374 2.8916602 -5.7155695 -5.4544435 2.526495
6.5253263 1.3852744 ]]
[[ -1.5733382 -0.08704215 2.6952646 7.385515 -0.7799995
3.1702318 -14.530704 ]
[ 0.05908662 -13.9438095 -1.154305 3.4328744 7.0506897
-5.0249805 2.5534477 ]
[ 0.61222774 0.14303133 4.685219 -7.0924406 1.7709903
1.0107443 -4.5374393 ]]
[[ -5.6678987 0.6903403 2.23693 1.2741803 -6.179094
3.0454116 -5.2941957 ]
[ 0.23656422 -2.2511265 3.3220747 2.021302 -3.070989
-3.815312 3.7513428 ]
[ 5.048253 5.163742 -3.064779 5.2195883 6.6997313
-2.0612605 2.076776 ]]]
[[[ -1.1741709 0.50855964 3.7991686 6.946745 -0.99349356
1.4751754 -1.08081 ]
[ 2.1064334 0.3293423 1.8446237 -0.3842956 3.8418627
-2.5760477 -4.709687 ]
[ -3.8787804 5.9237094 -3.8139226 3.2697144 -2.5398688
4.3881574 11.573359 ]]
[[ -3.1857545 7.100687 -3.9305675 0.6854049 -1.2562029
1.2753329 8.361776 ]
[ 2.7635245 -1.649135 -1.3044827 5.9628034 7.0507197
8.040147 -0.5544966 ]
[ 6.0894575 1.864697 2.0811782 -8.773295 3.7755995
5.5564737 -3.4745088 ]]
[[ 1.3517151 2.8740213 -6.181453 0.21349654 -5.9370227
-1.6817973 3.0836923 ]
[ -0.7866033 2.7180645 3.2119308 4.905232 -3.8589058
-3.349786 -1.2415386 ]
[ 7.3208423 7.184522 1.8396591 0.25130635 4.5287986
-1.9662986 -5.4157324 ]]]
[[[ -1.796482 -0.19289398 0.08456608 9.18009 4.3642817
3.9750414 10.058201 ]
[ -3.404979 10.002911 2.6454616 0.09656489 -5.6097493
2.0856397 8.30741 ]
[ -6.1940312 0.20053774 7.5518293 1.6553136 -6.075909
1.9946573 -8.276907 ]]
[[ 1.5515908 -4.065265 6.201588 -10.958014 2.8450232
1.7398013 6.308612 ]
[ 1.3526641 -0.20383507 -0.97939104 -12.001176 6.5776787
7.0159016 -2.6269057 ]
[ -3.5487242 -2.0833373 2.128775 8.243093 -1.1012591
3.3278828 0.64393663]]
[[ 2.3041837 -1.2524377 3.4256964 3.190121 0.32376206
1.0883296 -3.531728 ]
[ -2.393531 0.57050663 -3.172806 7.0572777 -0.7350081
-2.5658474 -6.9233646 ]
[ -1.0682559 -0.22647202 10.799706 -5.5458803 -3.2260892
-0.6237745 6.320084 ]]]
[[[ 8.890318 1.926058 -5.8980203 3.4635465 2.0711088
-1.0413806 -6.304987 ]
[ -7.1290493 -8.781645 -10.162883 3.1751637 2.1062303
-0.04042304 -14.788281 ]
[ -1.382834 -7.988844 2.7986026 -1.9692816 0.30068183
-1.4710974 -5.3116736 ]]
[[ -7.576119 -3.2894049 0.7375753 -1.3818941 2.9862103
-6.683834 -7.8058653 ]
[ 4.9312177 -0.04471028 -0.34124258 8.375692 -8.983649
-2.1781216 -12.752575 ]
[ 9.337945 -5.1725883 10.788802 0.9727853 -2.5389743
1.0551623 1.4216776 ]]
[[ 1.5142308 4.546703 -2.5327616 4.6643023 -2.0437615
-1.7893765 4.8349857 ]
[ 3.843536 8.979685 -5.5770497 12.787272 3.2864804
-9.081071 5.1559086 ]
[ -3.7020745 9.714738 -5.7880783 -2.3634226 4.0264153
5.8175054 -7.454776 ]]]]
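To make it clearer where each printed number comes from, here is a minimal pure-Python sketch of the arithmetic for a single image with padding='VALID'. It is a slow reference for illustration, not how TensorFlow implements the op, and the function name conv2d_valid is made up. Each output value is the sum, over the kernel window and over all input channels, of input value times kernel weight. The kernel is not flipped.

```python
def conv2d_valid(image, kernel, stride=1):
    """Naive VALID convolution of one image.

    image:  nested lists indexed [height][width][in_channel]
    kernel: nested lists indexed [filter_height][filter_width][in_channel][out_channel]
    Returns nested lists indexed [out_height][out_width][out_channel].
    """
    in_h, in_w, in_c = len(image), len(image[0]), len(image[0][0])
    f_h, f_w, out_c = len(kernel), len(kernel[0]), len(kernel[0][0][0])
    out_h = (in_h - f_h) // stride + 1
    out_w = (in_w - f_w) // stride + 1

    output = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            pixel = []
            for o in range(out_c):
                total = 0.0
                # Sum over the kernel window and over every input channel.
                for di in range(f_h):
                    for dj in range(f_w):
                        for c in range(in_c):
                            total += (image[i * stride + di][j * stride + dj][c]
                                      * kernel[di][dj][c][o])
                pixel.append(total)
            row.append(pixel)
        output.append(row)
    return output


# A 3x3 image with 2 channels, every value 1.0
image = [[[1.0, 1.0] for _ in range(3)] for _ in range(3)]
# A 3x3 kernel with 2 input channels and 1 output channel, every weight 1.0
kernel = [[[[1.0], [1.0]] for _ in range(3)] for _ in range(3)]
print(conv2d_valid(image, kernel))  # [[[18.0]]]
```

The single output value 18.0 is 3 * 3 * 2 products of 1.0 added together, which mirrors case 3: all input channels collapse into one number per kernel, so the last dimension of the output is the number of kernels.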
References:
1. https://blog.csdn.net/mao_xiao_feng/article/details/53444333
2. https://www.tensorflow.org/api_docs/python/tf/nn/conv2d
3. An introduction to how CNNs work: https://blog.csdn.net/v_july_v/article/details/51812459