Bootstrap

Python从0到100(八十五):神经网络-使用迁移学习完成猫狗分类

在这里插入图片描述

前言: 零基础学Python:Python从0到100最新最全教程 想做这件事情很久了,这次我更新了自己所写过的所有博客,汇集成了Python从0到100,共一百节课,帮助大家一个月时间里从零基础到学习Python基础语法、Python爬虫、Web开发、 计算机视觉、机器学习、神经网络以及人工智能相关知识,成为学习学习和学业的先行者!
欢迎大家订阅专栏:零基础学Python:Python从0到100最新最全教程!

今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~

说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下任务:

  1. 实现基于tensorflow和keras的迁移学习

  2. 加载tensorflow提供的数据集(不得使用cifar10)

  3. 需要使用markdown单元格对数据集进行说明

  4. 加载tensorflow提供的预训练模型(不得使用vgg16)

  5. 需要使用markdown单元格对原始模型进行说明

  6. 网络末端连接任意结构的输出端网络

  7. 用图表显示准确率和损失函数

  8. 用cnn工具可视化一批数据的预测结果

  9. 用cnn工具可视化一个数据样本的各层输出

一、加载数据集

1.调用库函数

import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import cnn_utils
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Dropout

2.加载数据集

数据集加载,数据是通过这个网站下载的猫狗数据集:http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip,实验中为了训练方便,我们取了一个较小的数据集。

path_to_zip = tf.keras.utils.get_file(
    'data.zip',
    origin='http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip',
    extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

3.数据集管理

使用image_dataset_from_director进行数据集管理,使用ImageDataGenerator训练过程中会出现错误,不知道是什么原因,就使用了原始的image_dataset_from_director方法进行数据集管理。

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)

validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)

二、猫狗数据集介绍

1.猫狗数据集介绍:

猫狗数据集包括25000张训练图片,12500张测试图片,包括猫和狗两种图片。在此次实验中为了训练方便,我们取了一个较小的数据集。 数据解压之后会有两个文件夹,一个是 “train”,一个是 “test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据。
在这里插入图片描述
在train文件夹里边是一些已经命名好的图像,有猫也有狗。而在test文件夹中是只有编号名的图像。
在这里插入图片描述

2.图片展示

下面是数据集中的图片展示:

class_names = ['cats', 'dogs']

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

三、MobileNetV2网络介绍

1.加载tensorflow提供的预训练模型

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)

2.轻量级网络——MobileNetV2

使用轻量级网络——MobileNetV2进行数据预处理 说明: MobileNetV2是基于倒置的残差结构,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩
在这里插入图片描述

3.MobileNetV2的网络模块

MobileNetV2的网络模块样子是这样的:
在这里插入图片描述
MobileNetV2是基于深度级可分离卷积构建的网络,它是将标准卷积拆分为了两个操作:深度卷积 和 逐点卷积,深度卷积和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上,而深度卷积针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说深度卷积是depth级别的操作。而逐点卷积其实就是普通的卷积,只不过其采用1x1的卷积核。
MobileNetV2的模型如下图所示,其中t为Bottleneck内部升维的倍数,c为通道数,n为该bottleneck重复的次数,s为sride
在这里插入图片描述

其中,当stride=1时,才会使用elementwise 的sum将输入和输出特征连接(如下图左侧);stride=2时,无short cut连接输入和输出特征(下图右侧):
在这里插入图片描述

四、搭建迁移学习

1.训练

inital_input = tf.keras.applications.mobilenet_v2.preprocess_input
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = False
base_model.summary()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

2.训练结果可视化

用图表显示准确率和损失函数

# 训练结果可视化,用图表显示准确率和损失函数
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range=range(initial_epochs)
plt.figure(figsize=(8,8))
plt.subplot(2,1,1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")
 
plt.subplot(2,1,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

3.输出训练的准确率

# 输出训练的准确率
test_loss, test_accuracy = model.evaluate(test_dataset)
print('test accuracy: {:.2f}'.format(test_accuracy))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

4.用cnn工具可视化一批数据的预测结果

label_dict = {
    0: 'cat',
    1: 'dog'
}

test_image_batch, test_label_batch = test_dataset.as_numpy_iterator().next()
# 编码成uint8 以图片形式输出
test_image_batch = test_image_batch.astype('uint8')

cnn_utils.plot_predictions(model, test_image_batch, test_label_batch, label_dict, 32, 5, 5)

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

5.数据输出

# 数据输出,数字化特征图
test_image_batch, test_label_batch = train_dataset.as_numpy_iterator().next()

img_idx = 0
random_batch = np.random.permutation(np.arange(0,len(test_image_batch)))[:BATCH_SIZE]
image_activation = test_image_batch[random_batch[img_idx]:random_batch[img_idx]+1]

cnn_utils.get_activations(base_model, image_activation[0])

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

6.用cnn工具可视化一个数据样本的各层输出

cnn_utils.display_activations(cnn_utils.get_activations(base_model, image_activation[0]))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

7.输出结果图像

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

文末送书

本期推荐1:

《Java面向对象程序设计:AI大模型给程序员插上翅膀》
AI工具助力Java编程:故事引领思政,AI助力学习;任务驱动实践,项目提升能力。
在这里插入图片描述

京东:https://item.jd.com/14850722.html

从AI助力角度出发,轻松学习编程
故事引入思政,引发读者动手实践
引出目标任务,明确学习目的和方向
AI学习问答与同步训练,提升学习效率
丰富的学习资源,助力实际项目开发
内容简介
随着云计算、物联网、大数据、人工智能等新一代信息技术的发展,Java 作为一种高性能、跨平台的编程语言,有着广泛的应用。本书从应用的角度详尽介绍了Java开发的核心技术。
全书分为12章,主要介绍了Java开发环境、Java编程基础、类和对象、继承和多态、抽象类和接口、Java常用类、内部类和泛型、集合容器、JDBC编程、图形用户界面设计、多线程,最后通过企业项目管理的方式进行实践,实现一个完整案例。
本书每章都通过故事的方式引入思政,并且从故事中引出目标任务。针对目标任务,辅以人工智能工具(ChatGPT、文心一言、讯飞星火)的帮助,得到行之有效的示例。之后对其进行知识解析,并完成上机练习。通过相关的练习巩固知识,并在合适的阶段引入一些常见的算法,加强学生的逻辑思维能力。在每章末尾有AI学习问答,让读者自行探索,同时加入同步训练,加强学习效果。

本期推荐2:

《Python金融大数据分析》
掌握Python,从零到一速成金融分析高手!实战案例深剖,让数字说话,让决策更精准!深入了解金融数据分析的具体过程和方法,提高实操能力。附赠书中案例源代码。
在这里插入图片描述

京东:https://item.jd.com/14827368.html

系统:全面构建Python金融大数据分析框架,从零到一,系统掌握核心技能,让学习之路有条不紊。
经典:凝聚笔者多年智慧,解读大数据在金融领域的应用,确保学习内容前沿且可靠。
深入:深度剖析Python在金融大数据分析中的关键技术,直击核心难点,助您深入理解数据背后的价值。
案例:精选实战案例,让您在真实场景中磨炼技能,实现从理论到实践的完美跨越。
内容简介
本书共分为11 章,全面介绍了以Python为工具的金融大数据的理论和实践,特别是量化投资和交易领域的相关应用,并配有项目实战案例。书中涵盖的内容主要有Python概览,结合金融场景演示Python的基本操作,金融数据的获取及实战,MySQL数据库详解及应用,Python在金融大数据分析方面的核心模块详解,金融分析及量化投资,Python量化交易,数据可视化Matplotlib,基于NumPy的股价统计分析实战,基于Matplotlib的股票技术分析实战,以及量化交易策略实战案例等。
本书内容通俗易懂,案例丰富,实用性强,特别适合以下人群阅读:金融行业的从业者、数据分析师、量化投资者、希望提高数据分析能力的投资者,以及对大数据分析感兴趣的编程人员。另外,本书也适合作为相关培训机构的教材。

;