Bootstrap

数字水印 | 奇异值分解 SVD 的 Python 代码实现

🥑原理:数字水印 | 奇异值分解 SVD 的定义、原理及性质

🥑参考:Python 机器学习笔记:奇异值分解(SVD)算法



正文

对于一个图像矩阵,我们总可以将其分解为以下形式:

在这里插入图片描述

通过选取不同个数 Σ \Sigma Σ 矩阵中的奇异值,就可以实现图像的压缩。

如果你没有了解过原理,那么你当然看不懂这是什么意思😇

如果想要实现图像的压缩,那么可以先使用 n u m p y \mathsf{numpy} numpy 库中的 linalg.svd 函数对图像矩阵进行分解,然后提取前 k k k 个奇异值以实现 SVD 图像压缩效果。下面让我们看一下代码。



1 核心代码

定义 s v d _ c o m p r e s s i o n \mathsf{svd\_compression} svd_compression 函数:

def svd_compression(img, k):
    res_image = np.zeros_like(img)

    for i in range(img.shape[2]):
        U, Sigma, VT = np.linalg.svd(img[:, :, i])
        res_image[:, :, i] = U[:, :k].dot(np.diag(Sigma[:k])).dot(VT[:k, :])

    return res_image

参数说明:

  • i m g \mathsf{img} img 是待处理的图像
  • k \mathsf{k} k 用于设置选定前 k k k 个奇异值

代码说明:

初始化 r e s _ i m a g e \mathsf{res\_image} res_image 变量,用于存放处理结果:

res_image = np.zeros_like(img)

循环压缩每一个通道:

for i in range(img.shape[2]):
        U, Sigma, VT = np.linalg.svd(img[:, :, i])
        res_image[:, :, i] = U[:, :k].dot(np.diag(Sigma[:k])).dot(VT[:k, :])
  • 参数: i m g . s h a p e [ 2 ] \mathsf{img.shape[2]} img.shape[2] 是图像的通道个数
  • 第一行:对第 i i i 个通道进行 SVD 分解
  • 第二行:取前 k k k 个奇异值重新构造图像

说明:由于 S i g m a \mathsf{Sigma} Sigma 矩阵除对角元素外,其余元素都为 0 \mathsf{0} 0,因此 linalg.svd 函数将其处理为一维矩阵返回。在重新构造图像时,我们需要使用 np.diag 函数将其还原为对角矩阵。



2 完整代码

import numpy as np
import cv2
from matplotlib import pyplot as plt


img = cv2.imread('white_bear.jpg')
img = img[:, :, [2, 1, 0]]
print('image shape is ', img.shape)


def svd_compression(img, k):
    res_image = np.zeros_like(img)

    for i in range(img.shape[2]):
        U, Sigma, VT = np.linalg.svd(img[:, :, i])
        res_image[:, :, i] = U[:, :k].dot(np.diag(Sigma[:k])).dot(VT[:k, :])

    return res_image


# 保留前 k 个奇异值
res1 = svd_compression(img, k=300)
res2 = svd_compression(img, k=200)
res3 = svd_compression(img, k=100)
res4 = svd_compression(img, k=50)

plt.subplot(1, 5, 1)
plt.title("image", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(img, cmap='gray')

plt.subplot(1, 5, 2)
plt.title("image", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res1, cmap='gray')

plt.subplot(1, 5, 3)
plt.title("u", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res2, cmap='gray')

plt.subplot(1, 5, 4)
plt.title("s", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res3, cmap='gray')

plt.subplot(1, 5, 5)
plt.title("v", fontsize=12, loc="center")
plt.axis('off')
plt.imshow(res4, cmap='gray')

plt.show()


;