生成器和yield关键字
1.生成器介绍:
概述:
它指的是 generator, 类似于以前学过的: 列表推导式, 集合推导式, 字典推导式…
作用:
降低资源消耗, 快速(批量)生成数据.
实现方式:
1.推导式写法.
my_generator = (i for i in range(5))
2.yield写法.
def get_generator():
for i in range(1, 6):
yield i # yield会记录每个生成的数据, 然后逐个的放到生成器对象中, 最终返回生成器对象.
问题: 如何从生成器对象中获取数据?
答案:
1.for循环遍历
2.next()函数, 逐个获取.
# 案例1: 回顾之前的列表推导式, 集合推导式.
# 需求: 生成 1 ~ 5 的数据.
my_list = [i for i in range(1, 6)]
print(my_list, type(my_list)) # [1, 2, 3, 4, 5] <class 'list'>
my_set = {i for i in range(1, 6)}
print(my_set, type(my_set)) # {1, 2, 3, 4, 5} <class 'set'>
# 案例2: 演示 生成器写法1, 推导式写法
# 尝试写一下, "元组"推导式, 发现打印的结果不是元组, 而是对象, 因为这种写法叫: 生成器.
my_tuple = (i for i in range(1, 6))
print(my_tuple) # <generator object <genexpr> at 0x0000024C90F056D0> 生成器对象
print(type(my_tuple)) # <class 'generator'> 生成器类型
print('-' * 31)
# 案例3: 如何从生成器对象中获取数据呢?
# 1. 定义生成器, 获取 1 ~ 5的数字.
my_generator = (i for i in range(1, 6))
# 2. 从生成器中获取数据.
# 格式1: for循环遍历
for i in my_generator:
print(i)
# 格式2: next()函数, 逐个获取.
print(next(my_generator)) # 1
print(next(my_generator)) # 2
2.yield关键字
# 案例: 演示 yield关键字方式, 获取生成器.
# 需求: 自定义 get_generator()函数, 获取 包括: 1 ~ 5之间的整数 生成器.
# 1. 定义函数.
def get_generator():
"""
用于演示 yield关键字的用法
:return: 生成器对象.
"""
# 思路1: 自定义列表, 添加指定元素, 并返回.
# my_list = []
# for i in range(1, 6):
# my_list.append(i)
# return my_list
# 思路2: yield写法, 即: 如下的代码, 效果同上.
for i in range(1, 6):
yield i # yield会记录每个生成的数据, 然后逐个的放到生成器对象中, 最终返回生成器对象.
# 在main中测试.
if __name__ == '__main__':
# 2. 调用函数, 获取生成器对象.
my_generator = get_generator()
# 3. 从生成器中获取每个元素.
print(next(my_generator)) # 1
print(next(my_generator)) # 2
print('-' * 31)
# 4. 遍历, 获取每个元素.
for i in my_generator:
print(i)
3.生成批次的数据
案例: 用生成器生成批次数据, 在模型训练中, 数据都是分批次来 "喂" 的.
需求: 读取项目下的 jaychou_lyrics.txt文件(其中有5000多条 歌词数据), 按照8个 / 批次, 获取生成器, 并从中获取数据.
"""
import math
# 需求1: 铺垫知识, math.ceil(数字): 获取指定数字的天花板数(向上取整), 即: 比这个数字大的所有整数中, 最小的哪个整数.
# print(math.ceil(5.1)) # 6
# print(math.ceil(5.6)) # 6
# print(math.ceil(5.0)) # 5
# 需求2: 获取生成器对象, 从文件中读数据数据, n条 / 批次
# 1. 定义函数 dataset_loader(batch_size), 表示: 数据生成器, 按照 batch_size条 分批.
def dataset_loader(batch_size): # 假设: batch_size = 8
"""
该函数用于获取生成器对象, 每条数据都是一批次的数据. 即: 生成器(8条, 8条, 8条...)
:param batch_size: 每批次有多少条数据
:return: 返回生成器对象.
"""
# 1.1 读取文件, 获取到每条(每行)数据.
with open("./jaychou_lyrics.txt", 'r', encoding='utf-8') as f:
# 一次读取所有行, 每行封装成字符串, 整体放到列表中.
data_lines = f.readlines() # 结果: [第一行, 第二行, 第三行...]
# 1.2 根据上述的数据, 计算出: 数据的总条数(总行数), 假设: 100行(条)
line_count = len(data_lines)
# 1.3 基于上述的总条数 和 batch_size(每批次的条数), 获取: 批次总数(即: 总共多少批)
batch_count = math.ceil(line_count / batch_size) # 例如: math.ceil(100 / 8) = 13
# 1.4 具体的获取每批次数据的动作, 用 yield包裹, 放到生成器中, 并最终返回生成器(对象)即可.
for i in range(batch_count): # batch_count的值: 13, i的值: 0, 1, 2, 3, 4, 5, .... 12
# 1.5 yield会记录每批次数据, 封装到生成器中, 并返回(生成器对象)
"""
推理过程:
i = 0, 代表第1批次数据, 想要 第 1 条 ~~~~ 第 8 条数据, 即: data_lines[0:8]
i = 1, 代表第2批次数据, 想要 第 9 条 ~~~~ 第 16 条数据, 即: data_lines[8:16]
i = 2, 代表第3批次数据, 想要 第 17 条 ~~~~ 第 24 条数据, 即: data_lines[16:24]
......
"""
yield data_lines[i * batch_size: i * batch_size + batch_size]
# 在main中, 测试调用
if __name__ == '__main__':
# 2. 获取生成器对象.
my_generator = dataset_loader(13)
# 3. 从生成器中获取第 1 批数据.
# print(next(my_generator))
# # 从第一批次中, 获取具体的每一条数据.
# for line in next(my_generator):
# print(line, end='')
#
# print('-' * 31)
#
# # 从第二批次中, 获取具体的每一条数据.
# for line in next(my_generator):
# print(line, end='')
# print('-' * 31)
# 4. 查看具体的每一批数据.
for batch_data in my_generator:
print(batch_data)
文件:jaychou_lyrics.txt