Python Advanced Project: Implementing CIFAR10 with TensorFlow in Python (.docx)
- Document ID: 27265230
- Upload date: 2023-06-28
- Format: DOCX
- Pages: 46
- Size: 41.33KB
Implementing CIFAR10 with TensorFlow in Python

Python Advanced Projects
Python has an enormous range of applications, from "Hello World" all the way to artificial intelligence.
There is practically no limit to the projects you can build with Python, but if you want to dig into the core of the language, consider these major project areas:
Machine learning with PyTorch, TensorFlow, Keras, or any machine learning library you prefer.
Computer vision with OpenCV and PIL.
Creating and publishing your own pip module, complete with tests and documentation.
Of these, my favorites are machine learning and deep learning.
Let's walk through an excellent deep learning use case to study Python in depth.
Implementing CIFAR10 with TensorFlow in Python
Let's train a network to classify images from the CIFAR10 dataset using a convolutional neural network built into TensorFlow.
To understand how this use case works, consider the following flow chart, broken down into simple components:
First, the images are loaded into the program.
The images are stored in a location the program can access.
The data is normalized, because Python needs the information in a form it can understand.
The basis of the neural network is defined.
A loss function is defined to push the model toward the highest possible accuracy on the dataset.
The actual model is trained, learning from the data it has seen.
The model is tested to analyze its accuracy, and the whole training process is iterated to achieve better accuracy.
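The normalization step above can be sketched in plain NumPy. This is only an illustration of the idea, not the actual loader in `include.data`; scaling pixels to [0, 1] and one-hot encoding the labels are assumptions about what `get_data_set` does internally.

```python
import numpy as np

def normalize_images(raw, num_classes=10):
    """Scale raw uint8 CIFAR-10 pixels to [0, 1] floats and
    one-hot encode the integer labels (illustrative sketch)."""
    images, labels = raw
    images = images.astype(np.float32) / 255.0      # pixel range 0..255 -> 0..1
    one_hot = np.zeros((len(labels), num_classes), dtype=np.float32)
    one_hot[np.arange(len(labels)), labels] = 1.0   # one 1 per row, at the class index
    return images, one_hot

# tiny fake batch: two 32x32 RGB images with labels 3 and 7
fake = (np.full((2, 32, 32, 3), 255, dtype=np.uint8), np.array([3, 7]))
imgs, ys = normalize_images(fake)
print(imgs.max(), ys.shape)
```

Feeding values in a consistent [0, 1] range keeps the gradients well-scaled, and the one-hot labels match the shape expected by a softmax cross-entropy loss.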
This use case is split into two programs: one trains the network, and the other tests it.
Let's train the network first.
Training the network
import numpy as np
import tensorflow as tf
from time import time
import math

from include.data import get_data_set
from include.model import model, lr

train_x, train_y = get_data_set("train")
test_x, test_y = get_data_set("test")
tf.set_random_seed(21)
x, y, output, y_pred_cls, global_step, learning_rate = model()
global_accuracy = 0
epoch_start = 0

# PARAMS
_BATCH_SIZE = 128
_EPOCH = 60
_SAVE_PATH = "./tensorboard/cifar-10-v1.0.0/"

# LOSS AND OPTIMIZER
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=output, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                   beta1=0.9,
                                   beta2=0.999,
                                   epsilon=1e-08).minimize(loss, global_step=global_step)

# PREDICTION AND ACCURACY CALCULATION
correct_prediction = tf.equal(y_pred_cls, tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# SAVER
merged = tf.summary.merge_all()
saver = tf.train.Saver()
sess = tf.Session()
train_writer = tf.summary.FileWriter(_SAVE_PATH, sess.graph)

try:
    print("\nTrying to restore last checkpoint ...")
    last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=_SAVE_PATH)
    saver.restore(sess, save_path=last_chk_path)
    print("Restored checkpoint from: ", last_chk_path)
except ValueError:
    print("\nFailed to restore checkpoint. Initializing variables instead.")
    sess.run(tf.global_variables_initializer())


def train(epoch):
    global epoch_start
    epoch_start = time()
    batch_size = int(math.ceil(len(train_x) / _BATCH_SIZE))
    i_global = 0

    for s in range(batch_size):
        batch_xs = train_x[s*_BATCH_SIZE: (s+1)*_BATCH_SIZE]
        batch_ys = train_y[s*_BATCH_SIZE: (s+1)*_BATCH_SIZE]

        start_time = time()
        i_global, _, batch_loss, batch_acc = sess.run(
            [global_step, optimizer, loss, accuracy],
            feed_dict={x: batch_xs, y: batch_ys, learning_rate: lr(epoch)})
        duration = time() - start_time

        if s % 10 == 0:
            percentage = int(round((s/batch_size)*100))

            bar_len = 29
            filled_len = int((bar_len * int(percentage)) / 100)
            bar = '=' * filled_len + '>' + '-' * (bar_len - filled_len)

            msg = "Global step: {:>5} - [{}] {:>3}% - acc: {:.4f} - loss: {:.4f} - {:.1f} sample/sec"
            print(msg.format(i_global, bar, percentage, batch_acc, batch_loss, _BATCH_SIZE / duration))

    test_and_save(i_global, epoch)


def test_and_save(_global_step, epoch):
    global global_accuracy
    global epoch_start

    i = 0
    predicted_class = np.zeros(shape=len(test_x), dtype=np.int)
    while i < len(test_x):
        j = min(i + _BATCH_SIZE, len(test_x))
        batch_xs = test_x[i:j, :]
        batch_ys = test_y[i:j, :]
        predicted_class[i:j] = sess.run(
            y_pred_cls,
            feed_dict={x: batch_xs, y: batch_ys, learning_rate: lr(epoch)})
        i = j

    correct = (np.argmax(test_y, axis=1) == predicted_class)
    acc = correct.mean() * 100
    correct_numbers = correct.sum()

    hours, rem = divmod(time() - epoch_start, 3600)
    minutes, seconds = divmod(rem, 60)
    mes = "\nEpoch {} - accuracy: {:.2f}% ({}/{}) - time: {:0>2}:{:0>2}:{:05.2f}"
    print(mes.format((epoch+1), acc, correct_numbers, len(test_x), int(hours), int(minutes), seconds))

    if global_accuracy != 0 and global_accuracy < acc:
        summary = tf.Summary(value=[
            tf.Summary.Value(tag="Accuracy/test", simple_value=acc),
        ])
        train_writer.add_summary(summary, _global_step)
        saver.save(sess, save_path=_SAVE_PATH, global_step=_global_step)

        mes = "This epoch receive better accuracy: {:.2f} > {:.2f}. Saving session..."
        print(mes.format(acc, global_accuracy))
        global_accuracy = acc

    elif global_accuracy == 0:
        global_accuracy = acc

    print("###########################################################################################################")


def main():
    train_start = time()

    for i in range(_EPOCH):
        print("\nEpoch: {}/{}\n".format((i+1), _EPOCH))
        train(i)

    hours, rem = divmod(time() - train_start, 3600)
    minutes, seconds = divmod(rem, 60)
    mes = "Best accuracy per session: {:.2f}, time: {:0>2}:{:0>2}:{:05.2f}"
    print(mes.format(global_accuracy, int(hours), int(minutes), seconds))


if __name__ == "__main__":
    main()

sess.close()
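The training loop feeds `learning_rate: lr(epoch)` into each `sess.run`, but `lr` is imported from `include.model` and never shown here. A plausible sketch of such a schedule, purely an assumption and not the repository's actual implementation, is a step decay that halves a base rate every 20 epochs:

```python
def lr(epoch):
    """Hypothetical step-decay schedule (stand-in for include.model.lr):
    start from a base learning rate of 1e-3 and halve it every 20 epochs."""
    base = 1e-3
    return base * (0.5 ** (epoch // 20))

# rate shrinks as training progresses: epochs 0-19, 20-39, 40-59
print(lr(0), lr(20), lr(59))
```

Passing the rate through a `feed_dict` like this lets the schedule live in ordinary Python while the graph treats `learning_rate` as a placeholder.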
Output:

Epoch: 60/60

Global step: 23070 - [>-----------------------------]   0% - acc: 0.9531 - loss: 1.5081 - 7045.4 sample/sec
Global step: 23080 - [>-----------------------------]   3% - acc: 0.9453 - loss: 1.5159 - 7147.6 sample/sec
Global step: 23090 - [=>----------------------------]   5% - acc: 0.9844 - loss: 1.4764 - 7154.6 sample/sec
Global step: 23100 - [==>---------------------------]   8% - acc: 0.9297 - loss: 1.5307 - 7104.4 sample/sec
Global step: 23110 - [==>---------------------------]  10% - acc: 0.9141 - loss: 1.5462 - 7091.4 sample/sec
Global step: 23120 - [===>--------------------------]  13% - acc: 0.9297 - loss: 1.5314 - 7162.9 sample/sec
Global step: 23130 - [====>-------------------------]  15% - acc: 0.9297 - loss: 1.5307 - 7174.8 sample/sec
Global step: 23140 - [=====>------------------------]  18% - acc: 0.9375 - loss: 1.5231 - 7140.0 sample/sec
Global step: 23150 - [=====>------------------------]  20% - acc: 0.9297 - loss: 1.5301 - 7152.8 sample/sec
Global step: 23160 - [======>-----------------------]  23% - acc: 0.9531 - loss: 1.5080 - 7112.3 sample/sec
Global step: 23170 - [=======>----------------------]  26% - acc: 0.9609 - loss: 1.5000 - 7154.0 sample/sec
Global step: 23180 - [========>---------------------]  28% - acc: 0.9531 - loss: 1.5074 - 6862.2 sample/sec
Global step: 23190 - [========>---------------------]  31% - acc: 0.9609 - loss: 1.4993 - 7134.5 sample/sec
Global step: 23200 - [=========>--------------------]  33% - acc: 0.9609 - loss: 1.4995 - 7166.0 sample/sec
Global step: 23210 - [==========>-------------------]  36% - acc: 0.9375 - loss: 1.5231 - 7116.7 sample/sec
Global step: 23220 - [===========>------------------]  38% - acc: 0.9453 - loss: 1.5153 - 7134.1 sample/sec
Global step: 23230 - [===========>------------------]  41% - acc: 0.9375 - loss: 1.5233 - 7074.5 sample/sec
Global step: 23240 - [============>-----------------]  43% - acc: 0.9219 - loss: 1.5387 - 7176.9 sample/sec
Global step: 23250 - [=============>----------------]  46% - acc: 0.8828 - loss: 1.5769 - 7144.1 sample/sec
Global step: 23260 - [==============>---------------]  49% - acc: 0.9219 - loss: 1.5383 - 7059.7 sample/sec
Global step: 23270 - [==============>---------------]  51% - acc: 0.8984 - loss: 1.5618 - 6638.6 sample/sec
Global step: 23280 - [===============>--------------]  54% - acc: 0.9453 - loss: 1.5151 - 7035.7 sample/sec
Global step: 23290 - [================>-------------]  56% - acc: 0.9609 - loss: 1.4996 - 7129.0 sample/sec
Global step: 23300 - [=================>------------]  59% - acc: 0.9609 - loss: 1.4997 - 7075.4 sample/sec
Global step: 23310 - [=================>------------]  61% - acc: 0.8750 - loss: 1.5842 - 7117.8 sample/sec
Global step: 23320 - [==================>-----------]  64% - acc: 0.9141 - loss: 1.5463 - 7157.2 sample/sec
Global step: 23330 - [===================>----------]  66% - acc: 0.9062 - loss: 1.5549 - 7169.3 sample/sec
Global step: 23340 - [====================>---------]  69% - acc: 0.9219 - loss: 1.5389 - 7164.4 sample/sec
Global step: 23350 - [====================>---------]  72% - acc: 0.9609 - loss: 1.5002 - 7135.4 sample/sec
Global step: 23360 - [=====================>--------]  74% - acc: 0.9766 - loss: 1.4842 - 7124.2 sample/sec
Global step: 23370 - [======================>-------]  77% - acc: 0.9375 - loss: 1.5231 - 7168.5 sample/sec
Global step: 23380 - [======================>-------]  79% - acc: 0.8906 - loss: 1.5695 - 7175.2 sample/sec
Global step: 23390 - [=======================>------]  82% - acc: 0.9375 - loss: 1.5225 - 7132.1 sample/sec
Global step: 23400 - [========================>-----]  84% - acc: 0.9844 - loss: 1.4768 - 7100.1 sample/sec
Global step: 23410 - [=========================>----]  87% - acc: 0.9766 - loss: 1.4840 - 7172.0 sample/sec
Global step: 23420 - [==========================>---]  90% - acc: 0.9062 - loss: 1.5542 - 7122.1 sample/sec
Global step: 23430 - [==========================>---]  92% - acc: 0.9297 - loss: 1.5313 - 7145.3 sample/sec
Global step: 23440 - [===========================>--]  95% - acc: 0.9297 - loss: 1.5301 - 7133.3 sample/sec
Global step: 23450 - [============================>-]  97% - acc: 0.9375 - loss: 1.5231 - 7135.7 sample/sec
Global step: 23460 - [=============================>] 100% - acc: 0.9250 - loss: 1.5362 - 10297.5 sample/sec

Epoch 60 - accuracy: 78.81% (7881/10000)
This epoch receive better accuracy: 78.81 > 78.78. Saving session...
###########################################################################################################
Running the network on the test dataset
import numpy as np
import tensorflow as tf

from include.data import get_data_set
from include.model import model

test_x, test_y = get_data_set("test")
x, y, output, y_pred_cls, global_step, learning_rate = model()

_BATCH_SIZE = 128
_CLASS_SIZE = 10
_SAVE_PATH = "./tensorboard/cifar-10-v1.0.0/"

saver = tf.train.Saver()
sess = tf.Session()

try:
    print("\nTrying to restore last checkpoint ...")
    last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=_SAVE_PATH)
    saver.restore(sess, save_path=last_chk_path)
    print("Restored checkpoint from: ", last_chk_path)
except ValueError:
    print("\nFailed to restore checkpoint. Initializing variables instead.")
    sess.run(tf.global_variables_initializer())
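The test script is cut off at this point in the source. After restoring the checkpoint, the usual next step is to run the network over the test set in batches and compare the predicted classes against the labels. The batching and accuracy arithmetic can be sketched in NumPy alone; `predict_batch` below is a hypothetical stand-in for the `sess.run(y_pred_cls, ...)` call, not part of the original script:

```python
import numpy as np

def batched_accuracy(test_x, test_y, predict_batch, batch_size=128):
    """Run predict_batch over test_x in slices of batch_size and
    score the predictions against one-hot labels test_y."""
    predicted = np.zeros(len(test_x), dtype=np.int64)
    i = 0
    while i < len(test_x):
        j = min(i + batch_size, len(test_x))   # last slice may be short
        predicted[i:j] = predict_batch(test_x[i:j])
        i = j
    correct = (np.argmax(test_y, axis=1) == predicted)
    return correct.mean() * 100.0, correct.sum()

# toy check: a "model" that always predicts class 0,
# against 10 samples of which 6 really are class 0
xs = np.zeros((10, 4), dtype=np.float32)
ys = np.eye(4)[[0, 0, 1, 2, 0, 3, 0, 0, 1, 0]]
acc, n_correct = batched_accuracy(
    xs, ys, lambda batch: np.zeros(len(batch), dtype=np.int64), batch_size=3)
print(acc, n_correct)  # 60.0 6
```

This mirrors the accuracy computation in `test_and_save` from the training program: argmax over the one-hot labels, element-wise comparison, then mean and sum.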