目录
一.项目介绍
超分辨率(Super-Resolution,SR),简称超分,是指利用光学及其相关知识,根据已知的图像信息恢复图像细节和其他数据信息的过程。简单来说,就是在增大图像分辨率的同时,防止图像质量下降。
GAN的全称是Generative Adversarial Networks,即生成对抗网络。生成对抗网络一般由一个生成器(生成网络)和一个判别器(判别网络)组成。
SRGAN使用了生成对抗的方式来进行图像的超分辨率重建,同时提出了一个由Adversarial Loss和Content Loss组成的损失函数。
网络架构:
分为两个网络:生成网络和判别网络
生成网络的作用是,通过学习训练集数据的特征,在判别器的指导下,将随机噪声分布尽量拟合为训练数据的真实分布,从而生成具有训练集特征的相似数据。
判别网络则负责区分输入的数据是真实的还是生成器生成的假数据,并反馈给生成器。
两个网络交替训练,能力同步提高,直到生成网络生成的数据能够以假乱真,并与判别网络的能力达到一定均衡。
二.项目流程详解
2.1.数据加载与配置
参数配置:
from easydict import EasyDict as edict
import json
# Root configuration object; main.py reads its TRAIN / VALID sections.
config = edict()

# ---- training section ----
config.TRAIN = edict()
_train = config.TRAIN  # short alias for the section stored above

## Adam
_train.batch_size = 4
_train.lr_init = 1e-4
_train.beta1 = 0.9

## initialize G
_train.n_epoch_init = 100
# _train.lr_decay_init = 0.1
# _train.decay_every_init = int(_train.n_epoch_init / 2)

## adversarial learning (SRGAN)
_train.n_epoch = 2000
_train.lr_decay = 0.1
# decay the learning rate once, half-way through adversarial training
_train.decay_every = _train.n_epoch // 2

## train set location
_train.hr_img_path = './srdata/DIV2K_train_HR'
_train.lr_img_path = './srdata/DIV2K_train_LR_bicubic/X4'

# ---- validation section ----
config.VALID = edict()
_valid = config.VALID

## test set location
_valid.hr_img_path = './srdata/DIV2K_valid_HR'
_valid.lr_img_path = './srdata/DIV2K_valid_LR_bicubic/X4'
def log_config(filename, cfg):
    """Write ``cfg`` to ``filename`` as 4-space indented JSON between two rule lines.

    Parameters
    ----------
    filename : path of the text file to (over)write.
    cfg : JSON-serialisable mapping, e.g. the ``config`` object.
    """
    rule = "================================================"
    body = json.dumps(cfg, indent=4)
    with open(filename, 'w') as out:
        out.write("{0}\n{1}\n{0}\n".format(rule, body))
数据加载:
# Collect the image file names with tl.files.load_file_list:
#   path - folder that holds the images,
#   regx - regular expression a file name must match. The dot is escaped and
#          the pattern anchored so that only names ending in ".png" qualify.
# The trailing slice [:x] keeps the first x names of the sorted list
# (loading too many images may cause a MemoryError).
train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx=r'.*\.png$', printable=False))[:100]
train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx=r'.*\.png$', printable=False))[:100]
valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx=r'.*\.png$', printable=False))[:50]
valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx=r'.*\.png$', printable=False))[:50]
# If your machine has enough memory, pre-load the whole train set.
# tl.vis.read_images loads the listed files:
#   first argument - the file names collected above,
#   path           - folder that contains those files,
#   n_threads      - number of threads used for reading (not an image count).
train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=8)
2.2.构建生成网络
# Placeholders are a graph-mode construct, so eager execution is switched off first.
tf.compat.v1.disable_eager_execution()
# Low-resolution input batch for the generator: batch_size images of 96x96x3.
lr_input_shape = [batch_size, 96, 96, 3]
t_image = tf.compat.v1.placeholder(
    'float32', lr_input_shape, name='t_image_input_to_SRGAN_generator')
# Build the generator.
# reuse=False is forwarded to the generator's variable scope: the variables
# are created here rather than shared with an existing copy of the network.
net_g = SRGAN_g(t_image, is_train=True, reuse=False)
SRGAN_g:
def SRGAN_g(t_image, is_train=False, reuse=False):
    """Build the SRGAN generator.

    Generator from "Photo-Realistic Single Image Super-Resolution Using a
    Generative Adversarial Network". Layer names encode the number of feature
    maps (n) and the stride (s), e.g. ``n64s1``.

    Parameters
    ----------
    t_image : tensor wrapped by the input layer (the low-resolution batch).
    is_train : forwarded to every BatchNormLayer.
    reuse : forwarded to the "SRGAN_g" variable scope; pass True to share the
        variables of a generator that was already built.

    Returns
    -------
    The last TensorLayer layer. Its ``outputs`` went through ``tanh`` and are
    4x larger spatially than the input (two sub-pixel convolutions, scale 2).
    """
    # weight initialiser
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    # initialiser for gamma (the batch-normalisation scale parameter)
    g_init = tf.random_normal_initializer(1., 0.02)
    with tf.compat.v1.variable_scope("SRGAN_g", reuse=reuse):
        # tl.layers.set_name_reuse(reuse) # remove for TL 1.8.0+
        # input layer
        n = InputLayer(t_image, name='in')
        # first convolution; its output is kept for the long skip connection
        n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
        temp = n

        # B residual blocks (16 of them)
        for i in range(16):
            nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
            nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            # residual connection: block input `n` + processed branch `nn`
            # (two convolutions and two batch normalisations)
            nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add/%s' % i)
            n = nn

        n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
        # long skip connection: add the output of the very first convolution
        n = ElementwiseLayer([n, temp], tf.add, name='add3')
        # B residual blocks end

        # reconstruction: two (convolution + sub-pixel shuffle) stages, each
        # doubling the spatial resolution
        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')
        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')

        # final 1x1 convolution down to 3 channels
        n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
        return n
2.3.构建判别网络
# Placeholders are a graph-mode construct, so eager execution is switched off first.
tf.compat.v1.disable_eager_execution()
# High-resolution target batch: batch_size images of 384x384x3.
hr_target_shape = [batch_size, 384, 384, 3]
t_target_image = tf.compat.v1.placeholder(
    'float32', hr_target_shape, name='t_target_image')
# Build the discriminator twice over the same variables.
# First pass: fed with the real images so it learns what "real" looks like.
# reuse=False - the discriminator variables are created here.
net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
# Second pass: fed with the generator output so it learns what "fake" looks like.
# reuse=True - share the variables created by the first pass.
_, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)
SRGAN_d:
def SRGAN_d(input_images, is_train=True, reuse=False):
    """Build the SRGAN discriminator.

    Six strided 4x4 convolutions (h0-h5), two 1x1 convolutions (h6, h7), one
    residual block, then a single-unit dense layer.

    Parameters
    ----------
    input_images : tensor wrapped by the input layer (real or generated batch).
    is_train : forwarded to every BatchNormLayer.
    reuse : forwarded to the "SRGAN_d" variable scope; pass True to share the
        variables of a discriminator that was already built.

    Returns
    -------
    (net_ho, logits) : the final layer, whose ``outputs`` hold the sigmoid
        probability, and the raw pre-sigmoid logits tensor.
    """
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    gamma_init = tf.random_normal_initializer(1., 0.02)
    df_dim = 64
    lrelu = lambda x: tl.act.lrelu(x, 0.2)

    with tf.compat.v1.variable_scope("SRGAN_d", reuse=reuse):
        # SRGAN_g drops this call for TL 1.8.0+; only invoke it when the
        # installed TensorLayer still provides it.
        set_name_reuse = getattr(tl.layers, 'set_name_reuse', None)
        if set_name_reuse is not None:
            set_name_reuse(reuse)

        net = InputLayer(input_images, name='input/images')
        net = Conv2d(net, df_dim, (4, 4), (2, 2), act=lrelu, padding='SAME', W_init=w_init, name='h0/c')

        # h1..h5: stride-2 convolution + batch norm, filters df_dim * 2, 4, 8, 16, 32
        for i in range(1, 6):
            net = Conv2d(net, df_dim * 2 ** i, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h%d/c' % i)
            net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h%d/bn' % i)

        # h6, h7: 1x1 convolutions that shrink the channel count again
        net = Conv2d(net, df_dim * 16, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h6/c')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h6/bn')
        net = Conv2d(net, df_dim * 8, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h7/c')
        net_h7 = BatchNormLayer(net, is_train=is_train, gamma_init=gamma_init, name='h7/bn')

        # residual branch on top of h7
        net = Conv2d(net_h7, df_dim * 2, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn')
        net = Conv2d(net, df_dim * 2, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c2')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn2')
        net = Conv2d(net, df_dim * 8, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c3')
        net = BatchNormLayer(net, is_train=is_train, gamma_init=gamma_init, name='res/bn3')
        net_h8 = ElementwiseLayer([net_h7, net], combine_fn=tf.add, name='res/add')
        net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2)

        # flatten the feature maps and pass them through the dense layer
        net_ho = FlattenLayer(net_h8, name='ho/flatten')
        net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity, W_init=w_init, name='ho/dense')
        logits = net_ho.outputs
        # sigmoid turns the logit into the real/fake decision value
        net_ho.outputs = tf.nn.sigmoid(net_ho.outputs)

    return net_ho, logits
2.4.VGG特征提取网络
## vgg inference. resize method codes 0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA
# Vgg19_simple_api asserts a 224x224 input, so both batches are resized first.
vgg_input_size = [224, 224]
# resize the real (target) images
# http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
t_target_image_224 = tf.image.resize_images(
    t_target_image, size=vgg_input_size, method=0, align_corners=False)  # resize_target_image_for_vgg
# resize the generated images
t_predict_image_224 = tf.image.resize_images(
    net_g.outputs, size=vgg_input_size, method=0, align_corners=False)  # resize_generate_image_for_vgg
# (x + 1) / 2 rescales before the VGG call, whose docstring asks for values in [0, 1].
# The first call creates the VGG variables, the second one shares them.
net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
_, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)
Vgg19_simple_api:
def Vgg19_simple_api(rgb, reuse):
    """Build the VGG 19 model.

    Parameters
    ----------
    rgb : rgb image placeholder [batch, height, width, 3], values scaled [0, 1];
        height and width must both be 224 (asserted below).
    reuse : forwarded to the "VGG19" variable scope; pass True to share the
        variables of a VGG network that was already built.

    Returns
    -------
    (network, conv) : the final ``fc8`` layer (1000 units, identity
        activation) and the layer after ``pool4``.
    """
    VGG_MEAN = [103.939, 116.779, 123.68]
    # (number of filters, number of 3x3 convolutions) for conv1 .. conv5
    conv_blocks = [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]

    with tf.compat.v1.variable_scope("VGG19", reuse=reuse):
        start_time = time.time()
        print("build model started")
        rgb_scaled = rgb * 255.0
        # Convert RGB to BGR
        red, green, blue = tf.split(rgb_scaled, 3, 3)
        for channel in (red, green, blue):
            assert channel.get_shape().as_list()[1:] == [224, 224, 1]
        # mean subtraction: every colour channel minus its own mean
        bgr = tf.concat(
            [
                blue - VGG_MEAN[0],
                green - VGG_MEAN[1],
                red - VGG_MEAN[2],
            ], axis=3)
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        # input layer
        network = InputLayer(bgr, name='input')

        # conv1 .. conv5: layers are created in the same order and under the
        # same names as a hand-written listing (conv1_1, conv1_2, pool1, ...)
        conv = None
        for block, (n_filter, n_convs) in enumerate(conv_blocks, start=1):
            for idx in range(1, n_convs + 1):
                network = Conv2d(
                    network, n_filter=n_filter, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME',
                    name='conv%d_%d' % (block, idx))
            network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool%d' % block)
            if block == 4:
                # output after pool4, (batch_size, 14, 14, 512): returned as `conv`
                conv = network
        # after pool5: (batch_size, 7, 7, 512)

        # fc 6~8: flatten and run through the fully connected layers
        network = FlattenLayer(network, name='flatten')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc6')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc7')
        network = DenseLayer(network, n_units=1000, act=tf.identity, name='fc8')
        print("build model finished: %fs" % (time.time() - start_time))
        return network, conv
2.5.损失函数
# ###========================== DEFINE TRAIN OPS ==========================###
# Discriminator loss.
# Real images are labelled with ones ...
real_targets = tf.ones_like(logits_real)
d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, real_targets, name='d1')
# ... and generated images with zeros.
fake_targets = tf.zeros_like(logits_fake)
d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, fake_targets, name='d2')
d_loss = d_loss1 + d_loss2

# Adversarial term of the generator: it wants its images judged as real,
# hence the ones labels on the fake logits.
adversarial_term = tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
g_gan_loss = 1e-3 * adversarial_term
# Pixel-wise comparison of the generated and the real images.
mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
# Comparison of both images after VGG feature extraction.
feature_mse = tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
vgg_loss = 2e-6 * feature_mse
# Total generator loss.
g_loss = mse_loss + vgg_loss + g_gan_loss
三.完整代码
main.py
#! /usr/bin/python
# -*- coding: utf8 -*-
#http://tensorlayercn.readthedocs.io/zh/latest/user/installation.html
import os
import time
import pickle, random
#from datetime import datetime
import numpy as np
import logging, scipy
import tensorflow as tf
import tensorlayer as tl
from model import SRGAN_g, SRGAN_d, Vgg19_simple_api
from utils import *
from config import config, log_config
###====================== HYPER-PARAMETERS ===========================###
_train_cfg = config.TRAIN  # every value below comes from the TRAIN section
## Adam
batch_size, lr_init, beta1 = _train_cfg.batch_size, _train_cfg.lr_init, _train_cfg.beta1
## initialize G
n_epoch_init = _train_cfg.n_epoch_init
## adversarial learning (SRGAN)
n_epoch, lr_decay, decay_every = _train_cfg.n_epoch, _train_cfg.lr_decay, _train_cfg.decay_every

# side length of the [ni, ni] grid used when a batch is saved as one sample image
ni = int(np.sqrt(batch_size))
def train():
    """Train SRGAN: pre-train the generator on MSE, then train G and D adversarially.

    Steps, in order:
      1. create the sample and checkpoint folders;
      2. pre-load the first 100 high-resolution training images;
      3. build the generator, discriminator and VGG19 graphs and their losses;
      4. restore generator / discriminator weights from ``checkpoint`` when
         present and load the VGG19 weights from ``vgg19.npy``;
      5. pre-train the generator on the MSE loss (epochs 0..n_epoch_init);
      6. train generator and discriminator together (epochs 0..n_epoch),
         decaying the learning rate every ``decay_every`` epochs.

    Sample images and ``.npz`` checkpoints are written every 10 epochs.

    Raises
    ------
    ValueError
        If fewer than ``batch_size`` training images were loaded.
    SystemExit
        If ``vgg19.npy`` is missing.
    """
    mode = tl.global_flag['mode']

    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(mode)
    save_dir_gan = "samples/{}_gan".format(mode)
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)
    g_init_ckpt = checkpoint_dir + '/g_{}_init.npz'.format(mode)
    g_ckpt = checkpoint_dir + '/g_{}.npz'.format(mode)
    d_ckpt = checkpoint_dir + '/d_{}.npz'.format(mode)

    ###====================== PRE-LOAD DATA ===========================###
    # tl.files.load_file_list returns the file names found in `path` that
    # match `regx` (dot escaped and anchored: only names ending in ".png").
    # The slice keeps the first 100 names; loading too many images may cause
    # a MemoryError.
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx=r'.*\.png$', printable=False))[:100]
    # If your machine has enough memory, pre-load the whole train set.
    # n_threads is the number of reader threads.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=8)
    if len(train_hr_imgs) < batch_size:
        raise ValueError(
            "need at least batch_size={} training images, found {} in {}".format(
                batch_size, len(train_hr_imgs), config.TRAIN.hr_img_path))
    # The placeholders have a fixed batch dimension, so only complete batches
    # are used; a trailing partial batch is skipped.
    # If the training set does not fit in memory, shuffle train_hr_img_list
    # and read each batch from disk inside the loops instead.
    batch_starts = range(0, len(train_hr_imgs) - batch_size + 1, batch_size)

    ###========================== DEFINE MODEL ============================###
    ## train inference
    tf.compat.v1.disable_eager_execution()
    t_image = tf.compat.v1.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.compat.v1.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    # generator; reuse=False creates its variables
    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    # discriminator on real images (creates the variables) ...
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    # ... and on generated images (shares the same variables)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. resize method codes 0, 1, 2, 3 = BILINEAR NEAREST BICUBIC AREA
    # resize both batches to the 224x224 input that Vgg19_simple_api asserts
    t_target_image_224 = tf.compat.v1.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)  # resize_target_image_for_vgg
    t_predict_image_224 = tf.compat.v1.image.resize_images(
        net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    # discriminator: real images labelled ones, generated images labelled zeros
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    # generator: wants its images judged real, hence ones on the fake logits
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    # pixel-wise comparison of generated and real images
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    # comparison of both images after VGG feature extraction
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    # total generator loss
    g_loss = mse_loss + vgg_loss + g_gan_loss

    # trainable variables of each network
    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.compat.v1.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.compat.v1.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.compat.v1.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.compat.v1.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.compat.v1.Session(
        config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # prefer the adversarially trained generator; fall back to the pre-trained one
    if tl.files.load_and_assign_npz(sess=sess, name=g_ckpt, network=net_g) is False:
        tl.files.load_and_assign_npz(sess=sess, name=g_init_ckpt, network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=d_ckpt, network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        # exits with a non-zero status and names the file that is actually checked
        raise SystemExit("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
    # The file stores a pickled dict, which NumPy only loads with
    # allow_pickle=True. Unpickling can execute code: use a trusted file only.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

    params = []
    for layer_name, weights in sorted(npz.items()):
        W = np.asarray(weights[0])
        b = np.asarray(weights[1])
        print(" Loading %s: %s, %s" % (layer_name, W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()
    print('ok')

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.compat.v1.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        for idx in batch_starts:
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            ## quick evaluation on train set
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)
            ## save model
            tl.files.save_npz(net_g.all_params, name=g_init_ckpt, sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.compat.v1.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.compat.v1.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        for idx in batch_starts:
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update D
            errD, _ = sess.run([d_loss, d_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            ## update G
            errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" %
                  (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
                                                                                total_g_loss / n_iter)
        print(log)

        if (epoch != 0) and (epoch % 10 == 0):
            ## quick evaluation on train set
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)
            ## save model
            tl.files.save_npz(net_g.all_params, name=g_ckpt, sess=sess)
            tl.files.save_npz(net_d.all_params, name=d_ckpt, sess=sess)
def evaluate(imid=64):
    """Super-resolve one validation image with the trained generator.

    Loads the validation LR/HR image pairs, restores the generator from
    ``checkpoint/g_srgan.npz``, runs it on the selected low-resolution image
    and saves four pictures under ``samples/<mode>``: the generated image,
    the LR input, the HR ground truth and a bicubic 4x upscale of the input.

    Parameters
    ----------
    imid : int, index of the validation image to process (default 64).
        Examples given with the original code: 0 penguin, 81 butterfly,
        53 bird, 64 castle.

    Raises
    ------
    IndexError
        If ``imid`` does not address a loaded validation image pair.
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    # dot escaped and pattern anchored: only names ending in ".png" match
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx=r'.*\.png$', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx=r'.*\.png$', printable=False))

    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=8)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=8)

    n_pairs = min(len(valid_lr_imgs), len(valid_hr_imgs))
    if not 0 <= imid < n_pairs:
        raise IndexError("imid={} is out of range: {} validation image pairs were loaded".format(imid, n_pairs))

    ###========================== DEFINE MODEL ============================###
    valid_lr_img = valid_lr_imgs[imid]
    valid_hr_img = valid_hr_imgs[imid]
    # valid_lr_img = get_imgs_fn('test.png', 'data2017/')  # if you want to test your own image
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

    size = valid_lr_img.shape
    # same graph-mode setup and tf.compat.v1 API as train()
    tf.compat.v1.disable_eager_execution()
    # height and width are left open so any image size can be fed
    t_image = tf.compat.v1.placeholder('float32', [1, None, None, 3], name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.compat.v1.Session(
        config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan.npz', network=net_g)

    ###======================= EVALUATION =============================###
    start_time = time.time()
    out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
    print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s / generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
    print("[*] save images")
    tl.vis.save_image(out[0], save_dir + '/valid_gen.png')
    tl.vis.save_image(valid_lr_img, save_dir + '/valid_lr.png')
    tl.vis.save_image(valid_hr_img, save_dir + '/valid_hr.png')

    # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3 - confirm the
    # installed SciPy still ships it before relying on the bicubic baseline.
    out_bicu = scipy.misc.imresize(valid_lr_img, [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
    tl.vis.save_image(out_bicu, save_dir + '/valid_bicubic.png')
if __name__ == '__main__':
    # Command-line entry point: `--mode srgan` trains, `--mode evaluate` runs
    # the generator on a validation image.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='srgan', help='srgan, evaluate')
    args = parser.parse_args()

    # train() and evaluate() read the mode back from this global flag
    tl.global_flag['mode'] = args.mode

    if tl.global_flag['mode'] == 'srgan':
        train()
    elif tl.global_flag['mode'] == 'evaluate':
        evaluate()
    else:
        raise Exception("Unknown --mode: %s" % tl.global_flag['mode'])
config.py
from easydict import EasyDict as edict
import json
# Root configuration object; main.py reads its TRAIN / VALID sections.
config = edict()

# ---- training section ----
config.TRAIN = edict()
_train = config.TRAIN  # short alias for the section stored above

## Adam
_train.batch_size = 4
_train.lr_init = 1e-4
_train.beta1 = 0.9

## initialize G
_train.n_epoch_init = 100
# _train.lr_decay_init = 0.1
# _train.decay_every_init = int(_train.n_epoch_init / 2)

## adversarial learning (SRGAN)
_train.n_epoch = 2000
_train.lr_decay = 0.1
# decay the learning rate once, half-way through adversarial training
_train.decay_every = _train.n_epoch // 2

## train set location
_train.hr_img_path = './srdata/DIV2K_train_HR'
_train.lr_img_path = './srdata/DIV2K_train_LR_bicubic/X4'

# ---- validation section ----
config.VALID = edict()
_valid = config.VALID

## test set location
_valid.hr_img_path = './srdata/DIV2K_valid_HR'
_valid.lr_img_path = './srdata/DIV2K_valid_LR_bicubic/X4'
def log_config(filename, cfg):
    """Write ``cfg`` to ``filename`` as 4-space indented JSON between two rule lines.

    Parameters
    ----------
    filename : path of the text file to (over)write.
    cfg : JSON-serialisable mapping, e.g. the ``config`` object.
    """
    rule = "================================================"
    body = json.dumps(cfg, indent=4)
    with open(filename, 'w') as out:
        out.write("{0}\n{1}\n{0}\n".format(rule, body))
download_imagenet.py
import argparse
import socket
import os
import urllib
import numpy as np
from PIL import Image
from joblib import Parallel, delayed
def download_image(download_str, save_dir):
    """Download one image and keep it only if it is large enough.

    Parameters
    ----------
    download_str : one line of the URL list, ``"<image id>\\t<url>"``.
    save_dir : existing directory; the image is stored as ``<image id>.jpg``.

    Best effort: nothing is raised for a malformed line or a failed download.
    A file that already exists is left untouched. A freshly downloaded file is
    deleted when it cannot be opened as an image, or when it is narrower or
    lower than 500 px or smaller than 200 KB.
    """
    # urlretrieve lives in urllib.request on Python 3 and in urllib on Python 2
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    try:
        img_name, img_url = download_str.strip().split('\t')
    except ValueError:
        # blank line, or not exactly "<id>\t<url>": skip it instead of crashing the worker
        print("Skip malformed line: {!r}".format(download_str))
        return

    save_img = os.path.join(save_dir, "{}.jpg".format(img_name))
    downloaded = False
    try:
        if not os.path.isfile(save_img):
            print("Downloading {} to {}.jpg".format(img_url, img_name))
            urlretrieve(img_url, save_img)

            # Check size of the images
            downloaded = True
            with Image.open(save_img) as img:
                width, height = img.size

            img_size_KB = os.path.getsize(save_img) / 1024

            if width < 500 or height < 500 or img_size_KB < 200:
                os.remove(save_img)
                print("Remove downloaded images (w:{}, h:{}, s:{}KB)".format(width, height, img_size_KB))
        else:
            print("Already downloaded {}".format(save_img))
    except Exception:
        # deliberate catch-all: one bad URL must not stop the whole batch
        if not downloaded:
            print("Cannot download.")
        else:
            print("Remove failed, downloaded images.")

        # never leave a partial or unreadable file behind
        if os.path.isfile(save_img):
            os.remove(save_img)
def main():
    """Download a random sample of the images listed in ``--img_url_file``.

    Command-line arguments:
      --img_url_file     file with one ``"<image id>\\t<url>"`` entry per line
      --output_dir       directory that receives the images (created if missing)
      --n_download_urls  how many entries to sample; capped at the number of
                         non-blank lines in the file
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_url_file", type=str, required=True,
                        help="File that contains list of image IDs and urls.")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory where to save outputs.")
    parser.add_argument("--n_download_urls", type=int, default=20000,
                        help="Number of urls to sample from the list and download.")
    args = parser.parse_args()

    # np.random.seed(123456)

    # applies to every socket opened afterwards, so a dead host cannot hang a worker
    socket.setdefaulttimeout(10)

    with open(args.img_url_file) as f:
        lines = [line for line in f if line.strip()]

    # sampling without replacement raises when asked for more items than exist
    n_samples = min(args.n_download_urls, len(lines))
    lines = np.random.choice(lines, size=n_samples, replace=False)

    # the downloader writes straight into this directory
    os.makedirs(args.output_dir, exist_ok=True)

    Parallel(n_jobs=12)(delayed(download_image)(line, args.output_dir) for line in lines)


if __name__ == "__main__":
    main()
model.py
#! /usr/bin/python
# -*- coding: utf8 -*-
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
import time
import os
# from tensorflow.python.ops import variable_scope as vs
# from tensorflow.python.ops import math_ops, init_ops, array_ops, nn
# from tensorflow.python.util import nest
# from tensorflow.contrib.rnn.python.ops import core_rnn_cell
# https://github.com/david-gpu/srez/blob/master/srez_model.py
def SRGAN_g(t_image, is_train=False, reuse=False):
    """Build the SRGAN generator (sub-pixel convolution variant).

    Generator from "Photo-Realistic Single Image Super-Resolution Using a
    Generative Adversarial Network". Layer names encode the number of feature
    maps (n) and the stride (s).

    Structure: a 3x3 stem convolution, 16 residual blocks, one more
    conv + batch-norm merged with the stem output, two conv + sub-pixel
    stages (``scale=2`` each) and a final 1x1 tanh convolution with 3 output
    channels.

    Parameters
    ----------
    t_image : tensor fed to ``InputLayer`` (the low-resolution image batch).
    is_train : forwarded to every ``BatchNormLayer``.
    reuse : forwarded to the ``"SRGAN_g"`` variable scope.

    Returns
    -------
    The output TensorLayer layer (the final ``'out'`` convolution).
    """
    # Weight initializer shared by all convolutions.
    weight_init = tf.random_normal_initializer(stddev=0.02)
    bias_init = None  # tf.constant_initializer(value=0.0)
    # Initializer for the batch-normalization gamma parameter.
    bn_gamma_init = tf.random_normal_initializer(1., 0.02)

    with tf.compat.v1.variable_scope("SRGAN_g", reuse=reuse):
        # Input layer followed by the stem convolution.
        net = InputLayer(t_image, name='in')
        net = Conv2d(net, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=weight_init, name='n64s1/c')
        stem = net

        # 16 residual blocks: conv -> BN(relu) -> conv -> BN, then add the
        # block input back onto the block output.
        for idx in range(16):
            branch = Conv2d(net, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='n64s1/c1/%s' % idx)
            branch = BatchNormLayer(branch, act=tf.nn.relu, is_train=is_train, gamma_init=bn_gamma_init, name='n64s1/b1/%s' % idx)
            branch = Conv2d(branch, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='n64s1/c2/%s' % idx)
            branch = BatchNormLayer(branch, is_train=is_train, gamma_init=bn_gamma_init, name='n64s1/b2/%s' % idx)
            net = ElementwiseLayer([net, branch], tf.add, name='b_residual_add/%s' % idx)

        # Long skip connection: merge the processed features with the stem output.
        net = Conv2d(net, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='n64s1/c/m')
        net = BatchNormLayer(net, is_train=is_train, gamma_init=bn_gamma_init, name='n64s1/b/m')
        net = ElementwiseLayer([net, stem], tf.add, name='add3')

        # Reconstruction: two conv + sub-pixel stages turn the low-resolution
        # feature maps into high-resolution ones.
        for stage in (1, 2):
            net = Conv2d(net, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=weight_init, name='n256s1/%s' % stage)
            net = SubpixelConv2d(net, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/%s' % stage)

        # Final convolution produces the 3-channel result.
        net = Conv2d(net, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=weight_init, name='out')
    return net
def SRGAN_g2(t_image, is_train=False, reuse=False):
    """Build the SRGAN generator (resize-convolution variant).

    Same stem and residual trunk as the sub-pixel generator, but upsampling
    is done with ``UpSampling2dLayer`` followed by conv + batch-norm instead
    of sub-pixel convolution. The original notes give 96x96 --> 384x384 as
    the intended sizes. Layer names encode the number of feature maps (n)
    and the stride (s). The variable scope is also named ``"SRGAN_g"``.

    Parameters
    ----------
    t_image : tensor fed to ``InputLayer``; its static shape supplies the
        height (index 1) and width (index 2) used for the upsampling sizes.
    is_train : forwarded to every ``BatchNormLayer``.
    reuse : forwarded to the variable scope.

    Returns
    -------
    The output TensorLayer layer (the final ``'out'`` convolution).
    """
    weight_init = tf.random_normal_initializer(stddev=0.02)
    bias_init = None  # tf.constant_initializer(value=0.0)
    bn_gamma_init = tf.random_normal_initializer(1., 0.02)
    in_shape = t_image.get_shape().as_list()

    with tf.variable_scope("SRGAN_g", reuse=reuse):
        net = InputLayer(t_image, name='in')
        net = Conv2d(net, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=weight_init, name='n64s1/c')
        stem = net

        # 16 residual blocks.
        for idx in range(16):
            branch = Conv2d(net, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='n64s1/c1/%s' % idx)
            branch = BatchNormLayer(branch, act=tf.nn.relu, is_train=is_train, gamma_init=bn_gamma_init, name='n64s1/b1/%s' % idx)
            branch = Conv2d(branch, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='n64s1/c2/%s' % idx)
            branch = BatchNormLayer(branch, is_train=is_train, gamma_init=bn_gamma_init, name='n64s1/b2/%s' % idx)
            net = ElementwiseLayer([net, branch], tf.add, name='b_residual_add/%s' % idx)

        # Long skip connection back to the stem output.
        net = Conv2d(net, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='n64s1/c/m')
        net = BatchNormLayer(net, is_train=is_train, gamma_init=bn_gamma_init, name='n64s1/b/m')
        net = ElementwiseLayer([net, stem], tf.add, name='add3')

        # Resize-conv upsampling to 2x and then 4x of the input size.
        # method codes: 0 BILINEAR, 1 NEAREST, 2 BICUBIC, 3 AREA.
        # The filter counts (64, 32) may need to be increased.
        for stage, (factor, n_filter) in enumerate(((2, 64), (4, 32)), start=1):
            target = [in_shape[1] * factor, in_shape[2] * factor]
            net = UpSampling2dLayer(net, size=target, is_scale=False, method=1, align_corners=False, name='up%d/upsample2d' % stage)
            net = Conv2d(net, n_filter, (3, 3), (1, 1), padding='SAME', W_init=weight_init, b_init=bias_init, name='up%d/conv2d' % stage)
            net = BatchNormLayer(net, act=tf.nn.relu, is_train=is_train, gamma_init=bn_gamma_init, name='up%d/batch_norm' % stage)

        net = Conv2d(net, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=weight_init, name='out')
    return net
def SRGAN_d2(t_image, is_train=False, reuse=False):
    """Build the paper-style SRGAN discriminator.

    Discriminator from "Photo-Realistic Single Image Super-Resolution Using a
    Generative Adversarial Network". Layer names encode the number of feature
    maps (n) and the stride (s): one 64-filter convolution, then seven
    conv + batch-norm pairs alternating stride 2 and stride 1 while the
    filter count grows to 512, followed by a 1024-unit dense layer and a
    single-unit output.

    Parameters
    ----------
    t_image : tensor fed to ``InputLayer``.
    is_train : forwarded to every ``BatchNormLayer``.
    reuse : forwarded to the ``"SRGAN_d"`` variable scope.

    Returns
    -------
    (layer, logits) : the output layer, whose ``outputs`` have been replaced
        by their sigmoid, and the raw pre-sigmoid tensor.
    """
    weight_init = tf.random_normal_initializer(stddev=0.02)
    bias_init = None
    bn_gamma_init = tf.random_normal_initializer(1., 0.02)
    leaky = lambda x: tl.act.lrelu(x, 0.2)

    # (filters, strides, layer-name prefix) of each conv + batch-norm pair.
    conv_bn_specs = (
        (64, (2, 2), 'n64s2'),
        (128, (1, 1), 'n128s1'),
        (128, (2, 2), 'n128s2'),
        (256, (1, 1), 'n256s1'),
        (256, (2, 2), 'n256s2'),
        (512, (1, 1), 'n512s1'),
        (512, (2, 2), 'n512s2'),
    )

    with tf.variable_scope("SRGAN_d", reuse=reuse):
        net = InputLayer(t_image, name='in')
        net = Conv2d(net, 64, (3, 3), (1, 1), act=leaky, padding='SAME', W_init=weight_init, name='n64s1/c')

        for n_filter, strides, prefix in conv_bn_specs:
            net = Conv2d(net, n_filter, (3, 3), strides, act=leaky, padding='SAME', W_init=weight_init, b_init=bias_init, name=prefix + '/c')
            net = BatchNormLayer(net, is_train=is_train, gamma_init=bn_gamma_init, name=prefix + '/b')

        net = FlattenLayer(net, name='f')
        net = DenseLayer(net, n_units=1024, act=leaky, name='d1024')
        net = DenseLayer(net, n_units=1, name='out')

        # Keep the raw scores, then expose sigmoid probabilities on the layer.
        logits = net.outputs
        net.outputs = tf.nn.sigmoid(net.outputs)
    return net, logits
def SRGAN_d(input_images, is_train=True, reuse=False):
    """Build the discriminator used to tell real images from generated ones.

    Structure: a stride-2 4x4 convolution (h0), five stride-2 4x4
    conv + batch-norm stages (h1..h5) that grow the filter count from
    ``df_dim * 2`` to ``df_dim * 32``, two 1x1 conv + batch-norm stages
    (h6, h7) that shrink it to ``df_dim * 8``, a three-convolution residual
    branch added back onto h7, and a single-unit dense output.

    Parameters
    ----------
    input_images : tensor fed to ``InputLayer``.
    is_train : forwarded to every ``BatchNormLayer`` (defaults to True here).
    reuse : forwarded to the ``"SRGAN_d"`` variable scope and to
        ``tl.layers.set_name_reuse``.

    Returns
    -------
    (layer, logits) : the output layer, whose ``outputs`` have been replaced
        by their sigmoid, and the raw pre-sigmoid tensor.
    """
    weight_init = tf.random_normal_initializer(stddev=0.02)
    bias_init = None  # tf.constant_initializer(value=0.0)
    bn_gamma_init = tf.random_normal_initializer(1., 0.02)
    base = 64
    leaky = lambda x: tl.act.lrelu(x, 0.2)

    # Start building the network.
    with tf.variable_scope("SRGAN_d", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        net = InputLayer(input_images, name='input/images')
        net = Conv2d(net, base, (4, 4), (2, 2), act=leaky, padding='SAME', W_init=weight_init, name='h0/c')

        # h1..h5: stride-2 convolutions with a growing number of filters.
        for level, mult in enumerate((2, 4, 8, 16, 32), start=1):
            net = Conv2d(net, base * mult, (4, 4), (2, 2), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='h%d/c' % level)
            net = BatchNormLayer(net, act=leaky, is_train=is_train, gamma_init=bn_gamma_init, name='h%d/bn' % level)

        # h6, h7: 1x1 convolutions reducing the number of filters again.
        net = Conv2d(net, base * 16, (1, 1), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='h6/c')
        net = BatchNormLayer(net, act=leaky, is_train=is_train, gamma_init=bn_gamma_init, name='h6/bn')
        trunk = Conv2d(net, base * 8, (1, 1), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='h7/c')
        trunk = BatchNormLayer(trunk, is_train=is_train, gamma_init=bn_gamma_init, name='h7/bn')

        # Residual branch on top of h7.
        branch = Conv2d(trunk, base * 2, (1, 1), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='res/c')
        branch = BatchNormLayer(branch, act=leaky, is_train=is_train, gamma_init=bn_gamma_init, name='res/bn')
        branch = Conv2d(branch, base * 2, (3, 3), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='res/c2')
        branch = BatchNormLayer(branch, act=leaky, is_train=is_train, gamma_init=bn_gamma_init, name='res/bn2')
        branch = Conv2d(branch, base * 8, (3, 3), (1, 1), act=None, padding='SAME', W_init=weight_init, b_init=bias_init, name='res/c3')
        branch = BatchNormLayer(branch, is_train=is_train, gamma_init=bn_gamma_init, name='res/bn3')

        # The leaky ReLU is applied after the addition, not inside the branch.
        merged = ElementwiseLayer([trunk, branch], combine_fn=tf.add, name='res/add')
        merged.outputs = tl.act.lrelu(merged.outputs, 0.2)

        # Flatten the convolution result and pass it through the dense layer.
        head = FlattenLayer(merged, name='ho/flatten')
        head = DenseLayer(head, n_units=1, act=tf.identity, W_init=weight_init, name='ho/dense')
        logits = head.outputs

        # Sigmoid gives the final real-vs-fake score.
        head.outputs = tf.nn.sigmoid(head.outputs)
    return head, logits
def Vgg19_simple_api(rgb, reuse):
    """Build the VGG-19 model.

    Parameters
    -----------
    rgb : rgb image placeholder [batch, height, width, 3] values scaled [0, 1].
        Height and width must be 224 (checked with assertions).
    reuse : forwarded to the ``"VGG19"`` variable scope.

    Returns
    -------
    (network, conv) : the ``fc8`` output layer and the layer right after
        ``pool4``.
    """
    VGG_MEAN = [103.939, 116.779, 123.68]
    # (number of filters, number of 3x3 convolutions) for conv1 .. conv5;
    # every stage ends with a 2x2 max-pool.
    stage_specs = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))

    with tf.variable_scope("VGG19", reuse=reuse):
        start_time = time.time()
        print("build model started")
        rgb_scaled = rgb * 255.0

        # Convert RGB to BGR.
        red, green, blue = tf.split(rgb_scaled, 3, 3)
        for channel in (red, green, blue):
            assert channel.get_shape().as_list()[1:] == [224, 224, 1]

        # Mean subtraction: each colour channel minus its own mean.
        bgr = tf.concat(
            [
                blue - VGG_MEAN[0],
                green - VGG_MEAN[1],
                red - VGG_MEAN[2],
            ], axis=3)
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        # Input layer.
        network = InputLayer(bgr, name='input')

        # conv1 .. conv5
        conv = None
        for stage, (n_filter, n_convs) in enumerate(stage_specs, start=1):
            for k in range(1, n_convs + 1):
                network = Conv2d(network, n_filter=n_filter, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv%d_%d' % (stage, k))
            network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool%d' % stage)
            if stage == 4:
                # Feature map after pool4: (batch_size, 14, 14, 512).
                conv = network
        # After pool5: (batch_size, 7, 7, 512).

        # fc 6~8: flatten the data and pass it through the dense layers.
        network = FlattenLayer(network, name='flatten')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc6')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc7')
        network = DenseLayer(network, n_units=1000, act=tf.identity, name='fc8')
        print("build model finished: %fs" % (time.time() - start_time))
        return network, conv
# def vgg16_cnn_emb(t_image, reuse=False):
# """ t_image = 244x244 [0~255] """
# with tf.variable_scope("vgg16_cnn", reuse=reuse) as vs:
# tl.layers.set_name_reuse(reuse)
#
# mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean')
# net_in = InputLayer(t_image - mean, name='vgg_input_im')
# """ conv1 """
# network = tl.layers.Conv2dLayer(net_in,
# act = tf.nn.relu,
# shape = [3, 3, 3, 64], # 64 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv1_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 64, 64], # 64 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv1_2')
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool1')
# """ conv2 """
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 64, 128], # 128 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv2_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 128, 128], # 128 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv2_2')
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool2')
# """ conv3 """
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 128, 256], # 256 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv3_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 256, 256], # 256 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv3_2')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 256, 256], # 256 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv3_3')
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool3')
# """ conv4 """
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 256, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv4_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv4_2')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv4_3')
#
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool4')
# conv4 = network
#
# """ conv5 """
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv5_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv5_2')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv5_3')
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool5')
#
# network = FlattenLayer(network, name='vgg_flatten')
#
# # # network = DropoutLayer(network, keep=0.6, is_fix=True, is_train=is_train, name='vgg_out/drop1')
# # new_network = tl.layers.DenseLayer(network, n_units=4096,
# # act = tf.nn.relu,
# # name = 'vgg_out/dense')
# #
# # # new_network = DropoutLayer(new_network, keep=0.8, is_fix=True, is_train=is_train, name='vgg_out/drop2')
# # new_network = DenseLayer(new_network, z_dim, #num_lstm_units,
# # b_init=None, name='vgg_out/out')
# return conv4, network
utils.py
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.prepro import *
# from config import config, log_config
#
# img_path = config.TRAIN.img_path
import scipy
import numpy as np
import os
def get_imgs_fn(file_name, path):
    """Read the image at ``path + file_name`` and return it as an RGB array.

    ``path`` and ``file_name`` are concatenated as-is (no separator is added).
    """
    # return scipy.misc.imread(path + file_name).astype(np.float)
    full_path = path + file_name
    image = scipy.misc.imread(full_path, mode='RGB')
    return image
def crop_sub_imgs_fn(x, is_random=True):
    """Crop a 384x384 patch from ``x`` and rescale it by ``v / 127.5 - 1``.

    ``is_random`` is forwarded to ``crop``. For inputs in [0, 255] the result
    lies in [-1, 1].
    """
    patch = crop(x, wrg=384, hrg=384, is_random=is_random)
    scaled = patch / (255. / 2.)
    return scaled - 1.
def downsample_fn(x):
    """Resize ``x`` to 96x96 with bicubic interpolation and rescale by ``v / 127.5 - 1``.

    The LR images are obtained by downsampling the HR images using a bicubic
    kernel with downsampling factor r = 4.
    """
    small = imresize(x, size=[96, 96], interp='bicubic', mode=None)
    scaled = small / (255. / 2.)
    return scaled - 1.
四.数据集
下载地址:
DIV2K Dataset: https://data.vision.ee.ethz.ch/cvl/DIV2K/
五.测试网络
def evaluate():
    """Super-resolve one validation image with the trained generator.

    Loads the validation LR/HR images, restores the generator weights from
    ``checkpoint/g_srgan.npz``, runs the generator on image number 64 and
    saves four files into ``samples/<mode>``: the generated image, the
    (rescaled) LR input, the HR ground truth and a bicubic 4x upscaling of
    the LR input for comparison.
    """
    # Create the folder for the result images.
    out_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(out_dir)
    ckpt_dir = "checkpoint"

    # ---- Pre-load the validation data ----
    hr_names = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    lr_names = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
    lr_images = tl.vis.read_images(lr_names, path=config.VALID.lr_img_path, n_threads=8)
    hr_images = tl.vis.read_images(hr_names, path=config.VALID.hr_img_path, n_threads=8)

    # ---- Define the model ----
    imid = 64  # 0: penguin, 81: butterfly, 53: bird, 64: castle
    hr_img = hr_images[imid]
    # To test your own image instead: lr_img = get_imgs_fn('test.png', 'data2017/')
    lr_img = (lr_images[imid] / 127.5) - 1  # rescale to [-1, 1]
    size = lr_img.shape

    # Older TensorLayer versions need the concrete image size here instead:
    # [None, size[0], size[1], size[2]].
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    # ---- Restore the generator ----
    session_cfg = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess = tf.Session(config=session_cfg)
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=ckpt_dir + '/g_srgan.npz', network=net_g)

    # ---- Evaluation ----
    start_time = time.time()
    out = sess.run(net_g.outputs, {t_image: [lr_img]})
    print("took: %4.4fs" % (time.time() - start_time))

    # e.g. LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
    print("LR size: %s / generated HR size: %s" % (size, out.shape))
    print("[*] save images")
    tl.vis.save_image(out[0], out_dir + '/valid_gen.png')
    tl.vis.save_image(lr_img, out_dir + '/valid_lr.png')
    tl.vis.save_image(hr_img, out_dir + '/valid_hr.png')

    # Bicubic 4x upscaling of the LR input as a baseline.
    bicubic_size = [size[0] * 4, size[1] * 4]
    out_bicu = scipy.misc.imresize(lr_img, bicubic_size, interp='bicubic', mode=None)
    tl.vis.save_image(out_bicu, out_dir + '/valid_bicubic.png')
低分辨率图像:
resize后的图像:
生成网络生成出的图像:
高分辨率图像: