Bootstrap

PyTorch——从入门到精通:PyTorch基础知识(normal 函数)【PyTorch系统学习】

torch.normal() 的用法

        该函数的参数如下:

normal(mean, std, *, generator=None, out=None)

        参数说明

  1. mean:

    • 均值,可以是一个数值(标量)或者张量。
    • 如果是张量,则指定生成正态分布的均值,形状需与标准差匹配
  2. std:

    • 标准差,可以是一个数值(标量)或者张量。
    • 如果是张量,则指定生成正态分布的标准差,形状需与均值匹配
  3. generator (可选): 用于生成随机数的随机数生成器。

  4. out (可选): 如果提供,则结果存储在这个张量中。

        返回值

        返回一个与 meanstd 的形状匹配的张量,值是符合指定正态分布的随机数。

        函数使用示例

        生成一个标量值

import torch
# 生成均值为0,标准差为1的正态分布随机数
result = torch.normal(0, 1)
print(result)

        生成一个张量

import torch
# 生成均值为0,标准差为1的3x3张量
result = torch.normal(0, 1, size=(3, 3))
print(result)

        不同均值和标准差的张量

import torch
# 均值和标准差为张量
mean = torch.tensor([0.0, 1.0, 2.0])
std = torch.tensor([1.0, 0.5, 0.25])
result = torch.normal(mean, std)
print(result)

        高维张量的均值和标准差

        当函数的作用对象拓展到高维张量时,相信还有很多小伙伴不太理解,torch.normal()函数中的均值和标准差是如何体现和作用的,接下来就针对于不同情况来细致的讲解一下:

        情况一:当 meanstd 是标量时,生成的矩阵中所有元素都使用同一个均值和标准差。

import torch

# 生成一个 3x3 的矩阵,每个元素均值为 0,标准差为 1
matrix = torch.normal(0, 1, size=(3, 3))
print(matrix)

        例如,在上述的代码中,虽然我们生成的是一个3✖3的矩阵,但由于我们的均值和标准差都是标量,因此每个元素都是从均值为 0、标准差为 1 的正态分布中独立采样。

        情况二:当 meanstd 是张量时,生成的矩阵中每个位置的元素使用对应位置的均值和标准差。

import torch

# 均值和标准差为矩阵
mean = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
std = torch.tensor([[0.1, 0.2], [0.3, 0.4]])

# 生成矩阵
matrix = torch.normal(mean, std)
print(matrix)

        这个时候,对于生成的矩阵matrix:

  • matrix[0,0] 从 N( 1.0, 0.1^2) 中采样。
  • matrix[0,1] 从 N( 2.0, 0.2^2) 中采样。
  • matrix[1,0] 从 N( 3.0, 0.3^2) 中采样。
  • matrix[1,1] 从 N( 4.0, 0.4^2) 中采样。

        这也是为什么在参数说明中强调均值和标准差的形状需要相互匹配,并且返回值的形状在默认情况下会与 meanstd 的形状一致,在这道题中的形状即为(2,2)。

        情况三:如果 meanstd 的形状不同,但满足广播规则,PyTorch 会自动扩展较小的张量以匹配较大的张量。广播机制的介绍曾在先前的博客(PyTorch基础知识(张量))中有所涉及。

import torch

mean = torch.tensor([1.0, 2.0])  # 1D 张量,形状为 (2,)
std = torch.tensor([[0.1], [0.2]])  # 2D 张量,形状为 (2, 1)

# 广播机制
matrix = torch.normal(mean, std)
print(matrix)

广播过程

  1. mean 的形状扩展为 (2, 2),则拓展后的均值为[[1, 2],[1,2]]。
  2. std 的形状扩展为 (2, 2),则拓展后的标准差为[[ 0.1, 0.1 ], [ 0.2, 0.2]]

结果

  • matrix[0,0]  从 N(1.0,0.1^2) 中采样。
  • matrix[0,1]  从 N(2.0,0.1^2) 中采样。
  • matrix[1,0]  从 N(1.0,0.2^2) 中采样。
  • matrix[1,1]  从 N(2.0,0.2^2) 中采样。

感谢阅读,希望对你有所帮助~

;