Bootstrap

图像处理学习笔记-20241118

霍夫变换

霍夫变换(Hough Transform)是一种用于检测图像中具有特定形状的特征(如直线、圆等)的技术。它通过将图像空间中的点映射到参数空间,以便从噪声中更可靠地检测出全局形状。霍夫变换广泛应用于图像处理和计算机视觉中,特别是用于直线和圆的检测。

基本原理

在图像中,直线可以用笛卡尔坐标系下的方程表示为:
y = − c o s θ / s i n θ ∗ x + r / s i n θ y = -cosθ/sinθ*x + r/sinθ y=cosθ/sinθx+r/sinθ
然而,这时候两个参数不知道r 和θ,因此用极坐标表示为r=f(theta)更为合适:
ρ = x cos ⁡ θ + y sin ⁡ θ \rho = x \cos \theta + y \sin \theta ρ=xcosθ+ysinθ
其中:

  • ρ \rho ρ 表示直线到坐标原点的距离。
  • θ \theta θ 表示直线与 x x x 轴的夹角。

对于图像空间中的每个点 ( x , y ) (x, y) (x,y),可以在参数空间中绘制对应的 ( ρ , θ ) (\rho, \theta) (ρ,θ) 曲线。如果多条曲线在某一点相交,说明这些点共线。
图片来源:【霍夫Hough直线变换原理检测算法的个人简易理解(极简版,看不会说话的吴克的霍夫直线检测观后感)】https://www.bilibili.com/video/BV1k44y1G772?vd_source=b1f5728f3a9c87006fa48f39e09acbab
在这里插入图片描述

霍夫变换的步骤

  1. 边缘检测:在应用霍夫变换之前,通常会先使用 Canny 边缘检测来提取图像的边缘。
  2. 参数空间映射:将图像空间中的每个边缘点映射到参数空间。
  3. 寻找峰值:在参数空间中查找具有最多交点的区域,这些区域对应于图像中的直线。

使用 OpenCV 实现直线检测

OpenCV 提供了 cv2.HoughLinescv2.HoughLinesP 函数来实现标准霍夫变换和概率霍夫变换。

示例:标准霍夫变换
import cv2
import numpy as np

# 读取图像并转换为灰度图
image = cv2.imread('example.jpg')
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# 使用 Canny 边缘检测
edges = cv2.Canny(gray, 50, 150)

# 应用标准霍夫变换
lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)

# 绘制检测到的直线
for line in lines:
    rho, theta = line[0]
    a = np.cos(theta)
    b = np.sin(theta)
    x0 = a * rho
    y0 = b * rho
    x1 = int(x0 + 1000 * (-b))
    y1 = int(y0 + 1000 * (a))
    x2 = int(x0 - 1000 * (-b))
    y2 = int(y0 - 1000 * (a))
    cv2.line(image, (x1, y1), (x2, y2), (0, 0, 255), 2)

cv2.imshow('Detected Lines', image)
cv2.waitKey(0)
cv2.destroyAllWindows()

示例:概率霍夫变换

概率霍夫变换是标准霍夫变换的优化版,使用随机采样来提高效率,并返回直线段而不是整条直线。

# 应用概率霍夫变换
lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=100, minLineLength=50, maxLineGap=10)

# 绘制检测到的直线段
for line in lines:
    x1, y1, x2, y2 = line[0]
    cv2.line(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

cv2.imshow('Detected Line Segments', image)
cv2.waitKey(0)
cv2.destroyAllWindows()

参数解释

  • rho:距离分辨率,即霍夫空间的 ρ \rho ρ 单位。
  • theta:角度分辨率,即霍夫空间的 θ \theta θ 单位。
  • threshold:累加器阈值,只有累加值大于该值时才被认为是一条直线。
  • minLineLength(用于 HoughLinesP):直线的最小长度。
  • maxLineGap(用于 HoughLinesP):直线上点之间的最大允许间隙。

霍夫变换检测圆

除了直线,霍夫变换还可以用于检测圆。OpenCV 提供了 cv2.HoughCircles 函数。

# 使用霍夫变换检测圆
circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, dp=1.2, minDist=30, param1=50, param2=30, minRadius=10, maxRadius=100)

# 绘制检测到的圆
if circles is not None:
    circles = np.round(circles[0, :]).astype("int")
    for (x, y, r) in circles:
        cv2.circle(image, (x, y), r, (0, 255, 0), 4)

cv2.imshow('Detected Circles', image)
cv2.waitKey(0)
cv2.destroyAllWindows()

基于GAN的样本生成

基于生成对抗网络(GAN,Generative Adversarial Networks)的数据增广是一种通过深度学习生成新样本的方法。这些新样本可用来扩展数据集,特别是在样本量有限的情况下提升模型性能。

GAN的基本原理

GAN由两个对抗的网络组成:

  1. 生成器(Generator)

    • 输入随机噪声,输出伪造的样本(例如图像、音频等)。
    • 目标是生成尽可能接近真实样本的伪样本。
  2. 判别器(Discriminator)

    • 输入样本,判断是真实样本还是生成样本。
    • 目标是正确区分真实样本和生成样本。

两者通过对抗性训练(生成器试图欺骗判别器,判别器试图更准确地区分)达到动态平衡。训练完成后,生成器可以生成逼真的新样本。

基于GAN的数据增广流程

  1. 准备数据

    • 收集现有的训练数据集,并对其进行预处理。
  2. 构建和训练GAN

    • 构建GAN网络,包括生成器和判别器。
    • 生成器 (Generator) 是一个神经网络,输入一个随机噪声向量,输出一张与真实图像相似的“假图像”。生成器使用了三层全连接层(Linear)和两个激活函数:第一层将噪声向量扩展到更高的维度。中间层通过 ReLU 激活函数引入非线性。输出层通过 Tanh 将生成图像的像素值归一化到 [ − 1 , 1 ] [-1, 1] [1,1],与数据预处理一致。
    • 判别器 (Discriminator) 是一个神经网络,用于判断输入的图像是真实图像(来自数据集)还是假图像(由生成器生成)。判别器也是一个全连接网络:三层全连接层逐渐降低维度。激活函数使用 LeakyReLU,避免“神经元死亡”问题(梯度过小导致无学习效果)。最后一层通过 Sigmoid 输出一个概率值,表示输入图像为真实的概率,范围为 [ 0 , 1 ] [0, 1] [0,1]
    • 使用训练数据训练GAN,确保生成器能够生成逼真的样本。
  3. 生成新样本

    • 使用训练好的生成器生成新的样本。
    • 新样本可以是多样化的、特定类别的,甚至是风格化的。
  4. 加入原始数据集

    • 将生成的样本加入原始数据集中。
    • 对扩展后的数据集进行训练,提高模型性能。

实现代码示例

下面是一个简单的基于GAN的数据增广实现示例(以MNIST为例):

import torch  # 导入PyTorch库,用于深度学习模型的构建与训练
import torch.nn as nn  # 导入神经网络模块
import torch.optim as optim  # 导入优化器模块
from torchvision import datasets, transforms  # 导入用于处理数据的工具
from torchvision.utils import save_image  # 导入用于保存生成图像的工具

# 定义生成器模型,生成假图像
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        # 使用全连接层和激活函数构建生成器网络
        self.net = nn.Sequential(
            nn.Linear(z_dim, 128),  # 输入维度为z_dim,输出为128
            nn.ReLU(),  # 使用ReLU激活函数
            nn.Linear(128, 256),  # 从128维到256维
            nn.ReLU(),  # 再次使用ReLU激活函数
            nn.Linear(256, img_dim),  # 输出维度为图像的展平大小
            nn.Tanh()  # 使用Tanh将输出值归一化到[-1, 1]范围
        )

    # 定义前向传播,接收输入噪声并生成图像
    def forward(self, z):
        return self.net(z)  # 将输入z通过生成器网络

# 定义判别器模型,用于区分真假图像
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        # 使用全连接层和激活函数构建判别器网络
        self.net = nn.Sequential(
            nn.Linear(img_dim, 256),  # 输入为图像展平后的维度,输出为256
            nn.LeakyReLU(0.2),  # 使用LeakyReLU激活函数,避免死亡神经元问题
            nn.Linear(256, 128),  # 从256维到128维
            nn.LeakyReLU(0.2),  # 再次使用LeakyReLU激活函数
            nn.Linear(128, 1),  # 最后一层输出1个值,表示真假的概率
            nn.Sigmoid()  # 使用Sigmoid激活函数,将输出映射到[0, 1]
        )

    # 定义前向传播,接收输入图像并输出真假概率
    def forward(self, img):
        return self.net(img)

# 参数设置
z_dim = 64  # 噪声向量的维度
img_dim = 28 * 28  # MNIST图像展平后的大小 (28x28像素)
lr = 0.0002  # 学习率

# 初始化生成器和判别器
generator = Generator(z_dim, img_dim)  # 生成器
discriminator = Discriminator(img_dim)  # 判别器

# 初始化优化器,用于更新生成器和判别器的参数
opt_gen = optim.Adam(generator.parameters(), lr=lr)  # 生成器的Adam优化器
opt_disc = optim.Adam(discriminator.parameters(), lr=lr)  # 判别器的Adam优化器

# 损失函数,使用二元交叉熵损失
criterion = nn.BCELoss()

# 数据加载和预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 将图像像素值归一化到[-1, 1]范围
])
data = datasets.MNIST(root="./data", transform=transform, download=True)  # 下载MNIST数据集
loader = torch.utils.data.DataLoader(data, batch_size=64, shuffle=True)  # 加载数据,批量大小为64

# 训练GAN模型
epochs = 10  # 训练的轮数
for epoch in range(epochs):  # 遍历每一轮
    for real, _ in loader:  # 遍历每个批次的真实图像
        real = real.view(-1, 28*28)  # 将图像展平为一维
        batch_size = real.size(0)  # 获取当前批次的大小

        # ---------------------
        # 训练判别器
        # ---------------------
        z = torch.randn(batch_size, z_dim)  # 生成随机噪声向量
        fake = generator(z)  # 用生成器生成假图像
        disc_real = discriminator(real).view(-1)  # 判别器对真实图像的输出
        disc_fake = discriminator(fake.detach()).view(-1)  # 判别器对假图像的输出(梯度不回传到生成器)
        # 判别器的损失函数,真实图像的标签为1,假图像的标签为0
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + \
                    criterion(disc_fake, torch.zeros_like(disc_fake))
        opt_disc.zero_grad()  # 清空判别器的梯度
        loss_disc.backward()  # 反向传播计算梯度
        opt_disc.step()  # 更新判别器参数

        # ---------------------
        # 训练生成器
        # ---------------------
        output = discriminator(fake).view(-1)  # 判别器对生成图像的输出
        # 生成器的损失函数,生成器希望判别器认为其生成的图像都为真(标签为1)
        loss_gen = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()  # 清空生成器的梯度
        loss_gen.backward()  # 反向传播计算梯度
        opt_gen.step()  # 更新生成器参数

    # 打印当前轮的损失
    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}")
    # 保存生成的图像,便于可视化
    save_image(fake.view(fake.size(0), 1, 28, 28), f"generated_{epoch+1}.png")

同态滤波(Homomorphic Filtering)

同态滤波是一种常用于图像处理中的技术,主要用于增强图像的对比度,尤其是用于抑制图像中的亮度不均匀性(例如由于光照变化或阴影)。同态滤波通过处理图像的亮度和反射分量来提高图像的整体质量。

同态滤波的核心思想是将图像分解为两部分:一部分表示图像的反射成分(物体表面的反射),另一部分表示光照成分。然后,利用频域滤波来增强或抑制这两个部分。

同态滤波的步骤

  1. 图像的对数变换
    同态滤波的第一步是对图像进行对数变换,以分离图像的亮度和反射成分。对数变换将图像的乘法性质转化为加法性质,从而使得图像的亮度成分与反射成分更容易分开。

    对图像进行对数变换后的表达式为:
    I ′ ( x , y ) = log ⁡ ( I ( x , y ) ) I'(x, y) = \log(I(x, y)) I(x,y)=log(I(x,y))
    其中,$ I(x, y) $ 是原始图像,$ I’(x, y) $ 是变换后的图像。

  2. 频域处理
    使用傅里叶变换将图像从空间域转换到频域,这样我们可以对图像的高频和低频部分进行分别处理。高频部分代表反射成分,低频部分代表光照成分。

    通过对频域中的图像进行滤波,我们可以增强或抑制某些频率成分。例如,可以对高频部分进行增强来提高细节,或对低频部分进行抑制来减少光照不均匀的影响。

  3. 反变换
    频域滤波完成后,使用反傅里叶变换将处理后的图像从频域转换回空间域。最后,使用指数变换将图像恢复到原始的动态范围。

    反变换后的图像表达式为:
    I ( x , y ) = exp ⁡ ( I ′ ( x , y ) ) I(x, y) = \exp(I'(x, y)) I(x,y)=exp(I(x,y))

同态滤波的实现步骤

以下是一个简单的同态滤波实现流程:

  1. 读取图像
  2. 对图像进行对数变换
  3. 进行傅里叶变换
  4. 应用频域滤波
  5. 进行反傅里叶变换
  6. 恢复图像的对数值并进行指数变换

示例代码

import cv2
import numpy as np
import matplotlib.pyplot as plt

def homomorphic_filter(image):
    # 1. 对数变换
    image_log = np.log1p(np.float32(image))  # 使用log1p避免log(0)
    
    # 2. 傅里叶变换
    f = np.fft.fft2(image_log)
    fshift = np.fft.fftshift(f)  # 将低频部分移到中心
    
    # 3. 设计一个高通滤波器
    rows, cols = image.shape
    crow, ccol = rows // 2, cols // 2
    mask = np.ones((rows, cols), np.float32)
    r = 30  # 低频区域半径
    center = [crow, ccol]
    x, y = np.fft.fftfreq(cols), np.fft.fftfreq(rows)
    X, Y = np.meshgrid(x, y)
    d = np.sqrt((X - center[1]) ** 2 + (Y - center[0]) ** 2)
    mask[d < r] = 0  # 低频部分置零
    
    # 4. 频域滤波
    fshift_filtered = fshift * mask
    
    # 5. 反傅里叶变换
    f_ishift = np.fft.ifftshift(fshift_filtered)
    image_back = np.fft.ifft2(f_ishift)
    
    # 6. 恢复图像并进行指数变换
    image_back = np.exp(np.abs(image_back)) - 1
    
    return np.uint8(image_back)

# 读取图像
image = cv2.imread('image.jpg', cv2.IMREAD_COLOR)
image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# 应用同态滤波
filtered_image = homomorphic_filter(image_gray)

# 显示结果
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(image_gray, cmap='gray')
plt.subplot(1, 2, 2)
plt.title("Filtered Image")
plt.imshow(filtered_image, cmap='gray')
plt.show()

偏振相机

光的偏振基础

光是电磁波,偏振是描述光波振动方向的属性。当光波的电场振动方向有特定排列时,我们称其为偏振光。

常见的偏振状态

  1. 自然光:没有固定的偏振方向。
  2. 线偏振光:电场在固定方向振动,通过偏振片,只允许一个方向的光振动通过。天空的蓝光(由大气散射产生)也含有偏振成分。
  3. 圆偏振光:电场旋转,轨迹呈圆形,使用四分之一波片,将线偏振光转变为圆偏振光。。
  4. 椭圆偏振光:电场旋转,轨迹呈椭圆形。线偏振光通过具有特定相位差的波片。

偏振相机的工作原理

偏振相机通过结合传统成像技术和偏振光学原理,记录光的偏振信息。其核心是利用微偏振滤光片阵列偏振光学元件,通过对偏振光的透射、反射等行为进行分析。

关键技术

  1. 微偏振滤光片阵列
    偏振相机的感光芯片(如 CMOS 或 CCD)表面覆盖了一层微偏振滤光片阵列,每个像素单元前安装不同角度的偏振滤光片。

    常见的滤光片方向为:

    • 0 ∘ 0^\circ 0(水平)
    • 4 5 ∘ 45^\circ 45
    • 9 0 ∘ 90^\circ 90(垂直)
    • 13 5 ∘ 135^\circ 135

    这样,每个像素记录一个特定方向的偏振光强。通过将相邻像素组合,得到同一位置处的偏振信息。

  2. 多角度偏振采样
    在偏振相机中,每组 2 × 2 2\times2 2×2 的像素阵列记录了 0 ∘ , 4 5 ∘ , 9 0 ∘ , 13 5 ∘ 0^\circ, 45^\circ, 90^\circ, 135^\circ 0,45,90,135 四个方向的光强信息。这些方向的光强用于计算偏振特性。


偏振信息的计算

Stokes 参数

偏振信息可通过 Stokes 参数 表示:

  • I I I:总光强
  • Q Q Q:水平偏振光强与垂直偏振光强的差异
  • U U U 4 5 ∘ 45^\circ 45 13 5 ∘ 135^\circ 135 偏振光强的差异
  • V V V 是圆偏振光的左旋和右旋分量之差。一般为0
    I = I 0 ∘ + I 9 0 ∘ I = I_{0^\circ} + I_{90^\circ} I=I0+I90
    Q = I 0 ∘ − I 9 0 ∘ Q = I_{0^\circ} - I_{90^\circ} Q=I0I90
    U = I 4 5 ∘ − I 13 5 ∘ U = I_{45^\circ} - I_{135^\circ} U=I45I135
偏振特性

利用 Stokes 参数计算偏振光的主要特性:

  • 偏振度 (DOP)
    DOP = Q 2 + U 2 + V 2 I \text{DOP} = \frac{\sqrt{Q^2 + U^2+V^2}}{I} DOP=IQ2+U2+V2
    表示偏振光强占总光强的比例。

  • 偏振角 (AOP)
    AOP = 1 2 arctan ⁡ U Q \text{AOP} = \frac{1}{2} \arctan{\frac{U}{Q}} AOP=21arctanQU
    表示光波振动方向的角度。


偏振相机的核心流程
  1. 光强采集
    相机通过微偏振滤光片阵列,在一次曝光中同时获取不同偏振方向的光强。

  2. 信号解算
    根据 I , Q , U I, Q, U I,Q,U 公式,计算每个像素的 Stokes 参数。

  3. 特性计算
    根据 Stokes 参数计算偏振度、偏振角等特性。

  4. 图像输出
    生成伪彩色图像,将偏振信息可视化。例如:

    • 红色通道:光强 ( I I I)
    • 绿色通道:偏振度 (DOP)
    • 蓝色通道:偏振角 (AOP)

偏振相机的硬件结构

Bayer格式与偏振相机的结合

在偏振相机中,Bayer格式的概念被延伸到微偏振滤光片阵列。每个像素位置不仅记录光强,还记录特定方向的偏振信息。例如:

[ 0 ∘ 4 5 ∘ 9 0 ∘ 13 5 ∘ ] \begin{bmatrix} 0^\circ & 45^\circ \\ 90^\circ & 135^\circ \end{bmatrix} [09045135]

每组 2 × 2 2 \times 2 2×2 像素记录四个方向的偏振光强,通过组合计算 Stokes 参数(如 I I I Q Q Q U U U),最终生成偏振度和偏振角图像。

  1. 微偏振阵列
    每个像素单元前覆盖一个小型偏振滤光片,阵列结构通常为 Bayer 格式。

    0 ∘ 0^\circ 0 4 5 ∘ 45^\circ 45 9 0 ∘ 90^\circ 90 13 5 ∘ 135^\circ 135
    I 0 ∘ I_{0^\circ} I0 I 4 5 ∘ I_{45^\circ} I45 I 9 0 ∘ I_{90^\circ} I90 I 13 5 ∘ I_{135^\circ} I135
  2. 图像传感器
    CMOS 或 CCD 芯片,用于记录偏振光经过滤光片后的强度。

  3. 处理器
    集成在相机中的嵌入式处理器或外部计算设备,用于实时计算偏振特性。


优点

  • 实时性:无需多次拍摄即可获取完整偏振信息。
  • 高精度:单像素解析多角度偏振光强。
  • 多功能性:可记录光强、偏振度和偏振角。

挑战

  • 分辨率限制:由于每个 2 × 2 2\times2 2×2 单元表示一个像素的偏振信息,实际分辨率降低。
  • 噪声问题:低光照条件下的偏振测量精度较低。
  • 价格较高:相比普通相机,偏振相机的价格和硬件要求更高。

金字塔 Retinex 算法(Pyramid Retinex Algorithm)

该理论的核心思想是通过模拟人眼的视觉机制来改善图像的视觉效果。Pyramid Retinex 是 Retinex 理论的一种改进,它结合了高斯金字塔和多尺度信息来进行图像增强。

Retinex 理论简介:

Retinex 理论由 Edwin Land 和 John McCann 在 1971 年提出,旨在解释人眼如何感知不同照明条件下的颜色和亮度。它假设人眼会分离图像的亮度和反射信息:

  • 亮度(Luminance): 与场景的光源强度相关。
  • 反射(Reflectance): 与物体的颜色或反射特性相关。

Retinex 算法试图通过分离并重建这两种信息,改善图像的对比度和颜色,使得图像在不同光照条件下看起来更加自然。

Pyramid Retinex 算法原理:

Pyramid Retinex 算法通过以下几个步骤对传统的 Retinex 算法进行改进:

  1. 金字塔分解:使用高斯金字塔将图像分解成多个尺度的低频和高频成分。每个尺度上的低频部分表示图像的整体照明(光源),而高频部分表示细节或反射信息。
  2. 多尺度信息融合:在多个尺度上分别应用 Retinex 算法,从不同的尺度中提取对比度增强信息。通过这种方式,算法能够同时考虑图像的细节和全局信息。
  3. 图像重建:最终,通过融合所有尺度的结果,恢复图像的反射信息,并增强图像的对比度。

Pyramid Retinex 处理流程:

  1. 高斯金字塔构建:首先通过高斯滤波将图像分解成不同尺度的图像。金字塔中的每一层表示图像的不同频率分量。
  2. 对比度增强:对每一层图像应用 Retinex 算法,通常是:
    Retinex ( I ) = log ⁡ ( I ) − log ⁡ ( G ∗ I ) \text{Retinex}(I) = \log(I) - \log(G * I) Retinex(I)=log(I)log(GI)
    其中 I I I 是图像, G G G 是高斯核, ∗ * 表示卷积操作。该过程分离了图像的反射成分和照明成分,增强了图像的对比度。
  3. 多尺度融合:通过对多个尺度上的结果进行加权和融合,得到最终的增强图像。
  4. 重建图像:最后,基于每个尺度的反射成分重建出最终的图像。

Pyramid Retinex 算法步骤:

  1. 输入图像:给定一个输入图像。
  2. 构建金字塔:使用高斯滤波器构建多个尺度的金字塔。
  3. 在每个尺度上应用 Retinex:对每一层图像应用 Retinex 算法,得到增强的图像。
  4. 融合结果:将各尺度上的增强结果进行融合,得到最终的增强图像。
  5. 输出增强图像:通过融合的图像恢复细节并增强图像的可视效果。

Pyramid Retinex 算法代码示例(Python):

下面是一个简单的 Pyramid Retinex 算法的实现框架:

import cv2
import numpy as np
import matplotlib.pyplot as plt

def gaussian_pyramid(img, levels=5):
    pyramid = [img]
    for i in range(levels-1):
        img = cv2.pyrDown(img)
        pyramid.append(img)
    return pyramid

def retinex(img, sigma=30):
    # 计算 Retinex 结果
    blurred = cv2.GaussianBlur(img, (0, 0), sigma)
    retinex = np.log1p(img) - np.log1p(blurred)
    return retinex

def pyramid_retinex(img, levels=5, sigma=30):
    # 构建高斯金字塔
    pyramid = gaussian_pyramid(img, levels)
    
    # 对每个尺度应用 Retinex
    retinex_results = []
    for level in pyramid:
        retinex_results.append(retinex(level, sigma))
    
    # 合并所有尺度上的结果
    final_result = np.zeros_like(img, dtype=np.float32)
    for res in retinex_results:
        final_result += res
    
    # 归一化输出结果
    final_result = np.expm1(final_result)  # 使用指数恢复反射
    final_result = np.clip(final_result, 0, 255).astype(np.uint8)
    
    return final_result

# 读取输入图像
image = cv2.imread('input_image.jpg')
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# 执行 Pyramid Retinex 算法
enhanced_image = pyramid_retinex(image_rgb)

# 显示结果
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(image_rgb)
plt.title("Original Image")
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(enhanced_image)
plt.title("Enhanced Image (Pyramid Retinex)")
plt.axis('off')

plt.show()
;