Bootstrap

去雾去雨算法

简单版

import cv2
import numpy as np

def dehaze(image):
    """Increase the contrast of a hazy BGR image by equalizing its luminance.

    The image is moved to YUV, only the Y (luma) plane is histogram-equalized,
    and the result is converted back to BGR. The U/V chroma planes are left
    untouched, so hues are kept while global contrast is stretched.

    Assumes an 8-bit BGR input (``cv2.equalizeHist`` only accepts an 8-bit
    single-channel plane).

    Args:
        image: BGR image, e.g. as returned by ``cv2.imread``.

    Returns:
        A new BGR image whose luminance histogram has been equalized.
    """
    yuv = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
    luma = yuv[:, :, 0]
    yuv[:, :, 0] = cv2.equalizeHist(luma)
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR)

def derain(image, ksize=5):
    """Simple rain removal for a BGR image using a median filter.

    Rain streaks are thin, high-frequency structures. A median filter replaces
    each pixel (per colour channel) with the median of its ``ksize x ksize``
    neighbourhood. This suppresses such thin streaks while preserving larger
    edges better than a Gaussian blur, and it keeps the image in colour.

    Note: subtracting a blurred copy from the image (a high-pass filter) does
    the opposite of deraining. It keeps only the fine detail, i.e. the edges
    *and* the rain, and throws the scene content away.

    Args:
        image: BGR image, e.g. as returned by ``cv2.imread``.
        ksize: odd aperture size >= 3. Larger values remove thicker streaks at
            the cost of more smoothing. OpenCV only supports apertures larger
            than 5 for 8-bit images.

    Returns:
        The filtered BGR image, same shape and dtype as ``image``.

    Raises:
        ValueError: if ``image`` is None or ``ksize`` is not an odd int >= 3.
    """
    if image is None:
        raise ValueError("derain() received None; was the image loaded correctly?")
    if ksize < 3 or ksize % 2 == 0:
        raise ValueError(f"ksize must be an odd integer >= 3, got {ksize!r}")

    # Median filtering per channel removes thin bright streaks without
    # converting to grayscale, so the colour information from dehaze() survives.
    return cv2.medianBlur(image, ksize)

# Load the input image. cv2.imread does not raise on a missing or unreadable
# file; it returns None, which would otherwise only surface later as an
# opaque cvtColor error inside dehaze().
image = cv2.imread('input_image.jpg')
if image is None:
    raise SystemExit("Could not read 'input_image.jpg'")

# Dehaze.
dehazed_image = dehaze(image)

# Derain.
derained_image = derain(dehazed_image)

# Save and display the result. cv2.imwrite reports failure through its
# boolean return value rather than an exception, so check it.
if not cv2.imwrite('output_image.jpg', derained_image):
    raise SystemExit("Could not write 'output_image.jpg'")
cv2.imshow('Result', derained_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

升级版

这段代码实现了一个基于**暗通道先验(Dark Channel Prior)**的单幅图像去雾算法。主要流程为:计算暗通道、估计大气光、估计粗略透射率,再用引导滤波对透射率进行细化,最后根据大气散射模型 $I = J\,t + A\,(1 - t)$ 反解出无雾图像 $J$,从而恢复图像的对比度和清晰度。

效果:

import PIL.Image as Image
import skimage.io as io
import numpy as np
import time
from gf import guided_filter
from numba import jit
import matplotlib.pyplot as plt

class HazeRemoval(object):
    """Single-image haze removal based on the dark channel prior.

    The methods are meant to be called in order, each storing its result on
    the instance for the next one: ``open_image`` -> ``get_dark_channel`` ->
    ``get_air_light`` -> ``get_transmission`` -> ``guided_filter`` ->
    ``recover`` -> ``show``.

    Args:
        omega: fraction of the haze removed when estimating the transmission.
        t0: lower bound applied to the transmission in ``recover``.
        radius: half-size of the square min-filter window (the window is
            ``2*radius+1`` pixels wide). Used for the dark channel and the
            transmission.
        r: radius passed to ``gf.guided_filter``.
        eps: regularization passed to ``gf.guided_filter``.

    Each method argument left as None falls back to the value given here.
    """

    def __init__(self, omega=0.95, t0=0.1, radius=7, r=60, eps=0.001):
        # ``r`` defaults to 60 to match what guided_filter() has always used.
        self.omega = omega
        self.t0 = t0
        self.radius = radius
        self.r = r
        self.eps = eps

    def open_image(self, img_path):
        """Load ``img_path`` as RGB in [0, 1] and allocate per-pixel buffers.

        The image is converted to 3-channel RGB so that grayscale, palette and
        RGBA files all yield a ``(rows, cols, 3)`` array.
        """
        with Image.open(img_path) as img:
            rgb = img.convert('RGB')
        self.src = np.asarray(rgb, dtype=np.double) / 255.
        self.rows, self.cols, _ = self.src.shape
        self.dark = np.zeros((self.rows, self.cols), dtype=np.double)
        self.Alight = np.zeros(3, dtype=np.double)
        self.tran = np.zeros((self.rows, self.cols), dtype=np.double)
        self.dst = np.zeros_like(self.src, dtype=np.double)

    @staticmethod
    def _min_filter(img, radius):
        """Return the min over a ``(2*radius+1)^2`` window around each pixel of a 2-D array.

        The window is clipped at the image border. Edge-replication padding
        only duplicates border pixels that already lie inside the clipped
        window, so the result equals a min over the clipped window. A
        rectangular min is separable, so it is computed as a 1-D min along
        rows followed by a 1-D min along columns.
        """
        if radius <= 0:
            return img.astype(np.double)
        win = 2 * radius + 1
        out = np.pad(img, radius, mode='edge')
        out = np.lib.stride_tricks.sliding_window_view(out, win, axis=0).min(axis=-1)
        out = np.lib.stride_tricks.sliding_window_view(out, win, axis=1).min(axis=-1)
        return out

    def get_dark_channel(self, radius=None):
        """Store in ``self.dark`` the per-pixel channel minimum of ``self.src``, min-filtered over the window."""
        radius = self.radius if radius is None else radius
        print("Starting to compute dark channel prior...")
        start = time.time()
        self.dark = self._min_filter(self.src.min(axis=2), radius)
        print("time:", time.time() - start)

    def get_air_light(self):
        """Estimate the atmospheric light ``self.Alight`` (one value per channel).

        The brightest 0.1% of dark-channel pixels are selected. Among them,
        the ``num`` largest values of each channel (sorted independently per
        channel) are averaged.
        """
        print("Starting to compute air light prior...")
        start = time.time()
        flat = np.sort(self.dark, axis=None)
        # At least one pixel: for images under 1000 px, int(...) is 0 and
        # flat[-0] would be the *minimum*, selecting every pixel.
        num = max(1, int(self.dark.size * 0.001))
        threshold = flat[-num]
        # Boolean indexing returns a copy, so sorting it leaves self.src intact.
        candidates = self.src[self.dark >= threshold]
        candidates.sort(axis=0)
        self.Alight = candidates[-num:, :].mean(axis=0)
        print("time:", time.time() - start)

    def get_transmission(self, radius=None, omega=None):
        """Store in ``self.tran`` the coarse transmission ``1 - omega * minfilter(min_c(src / A))``."""
        radius = self.radius if radius is None else radius
        omega = self.omega if omega is None else omega
        print("Starting to compute transmission...")
        start = time.time()
        # Guard against a zero air-light channel (e.g. an all-black image).
        air = np.maximum(self.Alight, np.finfo(np.double).eps)
        normalized = (self.src / air).min(axis=2)
        self.tran = 1. - omega * self._min_filter(normalized, radius)
        print("time:", time.time() - start)

    def guided_filter(self, r=None, eps=None):
        """Smooth ``self.tran`` into ``self.gtran`` with ``gf.guided_filter``, using ``self.src`` as the guide.

        NOTE(review): the argument order (guide, input, r, eps) is taken from
        the existing call; confirm it against the gf module.
        """
        r = self.r if r is None else r
        eps = self.eps if eps is None else eps
        print("Starting to compute guided filter transmission...")
        start = time.time()
        # This resolves to the module-level function imported from gf, not
        # this method: class attributes are not in a method's scope.
        self.gtran = guided_filter(self.src, self.tran, r, eps)
        print("time:", time.time() - start)

    def recover(self, t0=None):
        """Invert ``I = J*t + A*(1 - t)`` to get the haze-free image ``self.dst`` (uint8 RGB).

        The transmission is floored at ``t0`` in a local copy, so
        ``self.gtran`` keeps the filter output.
        """
        t0 = self.t0 if t0 is None else t0
        print("Starting recovering...")
        start = time.time()
        t = np.maximum(self.gtran, t0)[:, :, np.newaxis]
        dst = (self.src - self.Alight) / t + self.Alight
        self.dst = np.clip(dst * 255, 0, 255).astype(np.uint8)
        print("time:", time.time() - start)

    def show(self):
        """Write the intermediate and final images under ``img/`` and the result to ``test.jpg``."""
        import os
        import cv2

        def to_u8(arr):
            # Float maps can leave [0, 1] (e.g. the transmission or the guided
            # filter output); clip first so the uint8 cast cannot wrap around.
            return (np.clip(arr, 0., 1.) * 255).astype(np.uint8)

        # cv2.imwrite fails silently when the directory is missing.
        os.makedirs("img", exist_ok=True)
        # RGB -> BGR for OpenCV.
        cv2.imwrite("img/src.jpg", to_u8(self.src)[:, :, (2, 1, 0)])
        cv2.imwrite("img/dark.jpg", to_u8(self.dark))
        cv2.imwrite("img/tran.jpg", to_u8(self.tran))
        cv2.imwrite("img/gtran.jpg", to_u8(self.gtran))
        cv2.imwrite("img/dst.jpg", self.dst[:, :, (2, 1, 0)])

        io.imsave("test.jpg", self.dst)


if __name__ == '__main__':
    import sys

    # Without this check a missing argument surfaces as a bare IndexError.
    if len(sys.argv) < 2:
        sys.exit(f"usage: python {sys.argv[0]} IMAGE_PATH")

    hr = HazeRemoval()
    hr.open_image(sys.argv[1])
    hr.get_dark_channel()
    hr.get_air_light()
    hr.get_transmission()
    hr.guided_filter()
    hr.recover()
    hr.show()

    

;