Bootstrap

torchvision.utils.make_grid详解

torchvision.utils.make_grid 是 PyTorch 的 torchvision 库中的一个函数,用于将多张图片排列成一个网格形式的单张图片,方便于可视化和展示。

功能和用途
当我们在训练或者测试深度学习模型时,经常需要将一批图片进行可视化,make_grid 函数可以将这批图片整理成一个网格,使得可以一次性地将多张图片展示在同一张图片中。这对于监控训练进度、可视化特征图等都非常有用。

参数说明

torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

tensor: 输入的张量,通常是一个形状为 (B, C, H, W) 的张量,其中 B 是批次大小,C 是通道数,H 和 W 是图片的高度和宽度。

nrow: 每行图片的数量,默认为 8。如果图片很多,可以适当增加这个值。

padding: 每张图片之间的填充像素数,默认为 2。这个参数可以控制网格中每张图片之间的间隔。

normalize: 是否对图片进行归一化处理,默认为 False。如果设置为 True,函数会将像素值归一化到 [0, 1] 范围。

range: 归一化的范围,如果不为 None,则会将像素值归一化到指定的范围。

scale_each: 如果为 True,则会对每张图片单独进行归一化处理。

pad_value: 填充值,用于设置填充像素的颜色,默认为 0。

返回值
返回一个形状为 (C, H, W) 的张量,其中 C 是通道数,H 和 W 是生成的网格图片的高度和宽度。

示例
假设我们有一个形状为 (16, 3, 64, 64) 的张量,其中包含了 16 张大小为 64x64 的彩色图片。我们可以使用 make_grid 将这些图片排列成一个网格:

import torch
import torchvision
import matplotlib.pyplot as plt

# 创建一些示例图片
images = torch.randn(16, 3, 64, 64)  # 16 张 64x64 的彩色图片

# 将图片制作成网格
grid_img = torchvision.utils.make_grid(images, nrow=4, padding=2)

# 可视化网格图片
plt.figure(figsize=(10, 10))
plt.imshow(grid_img.permute(1, 2, 0))  # 调整通道顺序以适应 matplotlib 的要求
plt.axis('off')
plt.show()

在这个示例中,make_grid 将 16 张图片排列成一个 4x4 的网格,每张图片之间有 2 个像素的填充。最后,我们使用 Matplotlib 显示了生成的网格图片。

使用注意事项
输入张量格式: make_grid 函数的输入张量应该是 (B, C, H, W) 形状的,其中 B 是批次大小,C 是通道数,H 和 W 是图片的高度和宽度。

归一化: 如果希望图片在显示时能够正确映射到颜色空间,通常需要对像素值进行归一化。

可视化: 生成的网格图片可以使用 Matplotlib 或者其他图片展示工具进行可视化。

通过 torchvision.utils.make_grid,可以方便地将多张图片整理成网格形式,便于深度学习模型的可视化和分析。

;