Paper reference: Deep Generative Filter for motion deblurring (paper interpretation)
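1. Image size: 256×256.
2. Generate an HDF5 file from the training and test images
HDF (Hierarchical Data Format) is a file format that can store many kinds of image and numeric data. Its files can be moved between different types of machines, and it supports parallel I/O.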
```python
import glob as gb
import os

import h5py
import numpy as np
from PIL import Image
from tqdm import tqdm


# Read the image at image_path, split it into the sharp (full) and blurred
# halves, resize both to size x size, and return them as numpy arrays
def format_image(image_path, size):
    image = Image.open(image_path).convert('RGB')
    width, height = image.size
    # Note: the full (sharp) image is on the left, the blurred image on the right
    image_full = image.crop((0, 0, width // 2, height))
    image_blur = image.crop((width // 2, 0, width, height))
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    image_full = image_full.resize((size, size), Image.LANCZOS)
    image_blur = image_blur.resize((size, size), Image.LANCZOS)
    return np.array(image_full), np.array(image_blur)


# Convert the train/test images into one HDF5 file (data/data.h5)
def build_hdf5(jpeg_dir, size=256):
    os.makedirs('data', exist_ok=True)
    hdf5_file = os.path.join('data', 'data.h5')
    with h5py.File(hdf5_file, 'w') as f:
        for data_type in tqdm(['train', 'test'], desc='create HDF5 dataset from images'):
            images_path = gb.glob(os.path.join(jpeg_dir, data_type, '*.jpg'))
            data_full, data_blur = [], []
            for image_path in images_path:
                image_full, image_blur = format_image(image_path, size)
                data_full.append(image_full)
                data_blur.append(image_blur)
            # Each dataset has shape (N, size, size, 3)
            f.create_dataset('%s_data_full' % data_type, data=np.array(data_full))
            f.create_dataset('%s_data_blur' % data_type, data=np.array(data_blur))
```
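To make the side-by-side layout concrete, here is a NumPy-only sketch of the same left/right split that `format_image` performs with PIL, followed by the stacking that turns per-image arrays into one dataset. The helper `split_pair` and the synthetic pair are hypothetical, for illustration only.

```python
import numpy as np


def split_pair(pair):
    """Split an (H, 2W, C) side-by-side pair into (sharp, blurred) halves.

    Assumes the same layout as format_image: sharp on the left, blurred on the right.
    """
    width = pair.shape[1]
    if width % 2:
        raise ValueError("expected an even width for a side-by-side pair")
    half = width // 2
    return pair[:, :half], pair[:, half:]


# Synthetic 256x512 RGB pair: left half 0 ("sharp"), right half 255 ("blurred")
pair = np.zeros((256, 512, 3), dtype=np.uint8)
pair[:, 256:] = 255
sharp, blurred = split_pair(pair)
print(sharp.shape, blurred.shape)  # (256, 256, 3) (256, 256, 3)

# Stacking per-image arrays, as build_hdf5 does, gives an (N, 256, 256, 3) array
batch = np.stack([sharp, sharp])
print(batch.shape)  # (2, 256, 256, 3)
```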
3. Generator network
```python
from keras.layers import (BatchNormalization, Convolution2D, Dropout, Input,
                          LeakyReLU, concatenate)
from keras.models import Model

channel_rate = 64  # the head therefore has 4 * 64 = 256 filters


def generator_model():
    # Input image; the spatial size is variable
    inputs = Input(shape=(None, None, 3))
    # The head
    h = Convolution2D(filters=4 * channel_rate, kernel_size=(3, 3), padding='same')(inputs)
    # The dense field: 10 dense blocks. The paper uses dilated convolution at
    # every even-numbered layer (dilation 1, 2, 3, 2, 1)
    dilations = [None, (1, 1), None, (2, 2), None, (3, 3), None, (2, 2), None, (1, 1)]
    x = h
    for i, dilation in enumerate(dilations):
        d = dense_block(inputs=x, dilation_factor=dilation)
        if i < len(dilations) - 1:
            # Dense connectivity: each block sees all previous feature maps
            x = concatenate([x, d])
    # The tail (takes the output of the last dense block)
    x = LeakyReLU(alpha=0.2)(d)
    x = Convolution2D(filters=4 * channel_rate, kernel_size=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    # The global skip connection
    x = concatenate([h, x])
    x = Convolution2D(filters=channel_rate, kernel_size=(3, 3), padding='same')(x)
    # PReLU is not used: its default per-position alphas need a fixed input
    # shape, but the input shape here is (None, None, 3)
    x = LeakyReLU(alpha=0.2)(x)
    # Output image in [-1, 1]
    outputs = Convolution2D(filters=3, kernel_size=(3, 3), padding='same',
                            activation='tanh')(x)
    return Model(inputs=inputs, outputs=outputs, name='Generator')


# Dense block: LeakyReLU -> 1x1 conv -> BN -> LeakyReLU -> 3x3 conv -> BN -> Dropout
def dense_block(inputs, dilation_factor=None):
    x = LeakyReLU(alpha=0.2)(inputs)
    x = Convolution2D(filters=4 * channel_rate, kernel_size=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    # The 3x3 convolutions along the dense field alternate between "spatial"
    # convolution and "dilated" convolution with a linearly increasing dilation factor
    if dilation_factor is not None:
        x = Convolution2D(filters=channel_rate, kernel_size=(3, 3), padding='same',
                          dilation_rate=dilation_factor)(x)
    else:
        x = Convolution2D(filters=channel_rate, kernel_size=(3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    # Dropout (the original comment said "Gaussian noise", but dropout is what is applied)
    x = Dropout(rate=0.5)(x)
    return x
```
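* 1. The generator has about 3 million trainable parameters and nearly 7,000 non-trainable parameters.
* 2. The input is a color image; the first convolution layer has 256 kernels of size 3×3.
* 3. Each densely connected block contains 2 Leaky ReLU activations, 2 Batch Normalization layers, one 1×1 and one 3×3 convolution, and a Dropout layer.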
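As a sanity check on the parameter counts above, the following sketch tallies the generator's parameters by hand. It assumes channel_rate = 64 (so the head has 4 × 64 = 256 filters), convolutions with one bias per output channel, and BatchNormalization layers with two trainable vectors (gamma, beta) and two non-trainable ones (moving mean, moving variance). The helper names are hypothetical.

```python
def conv_params(c_in, c_out, k):
    # k x k kernel weights plus one bias per output channel
    return k * k * c_in * c_out + c_out


def bn_params(c):
    # (trainable gamma + beta, non-trainable moving mean + variance)
    return 2 * c, 2 * c


def generator_param_count(channel_rate=64, n_blocks=10):
    trainable, frozen = 0, 0
    head = 4 * channel_rate
    trainable += conv_params(3, head, 3)  # head: 3x3 conv, 3 -> 256 channels
    c_in = head
    for _ in range(n_blocks):
        # dense block: 1x1 conv -> BN -> 3x3 conv -> BN
        trainable += conv_params(c_in, 4 * channel_rate, 1)
        t, f = bn_params(4 * channel_rate)
        trainable, frozen = trainable + t, frozen + f
        trainable += conv_params(4 * channel_rate, channel_rate, 3)
        t, f = bn_params(channel_rate)
        trainable, frozen = trainable + t, frozen + f
        c_in += channel_rate  # concatenation widens the next block's input
    # tail: 1x1 conv on the last block's output (channel_rate channels) -> BN
    trainable += conv_params(channel_rate, 4 * channel_rate, 1)
    t, f = bn_params(4 * channel_rate)
    trainable, frozen = trainable + t, frozen + f
    # global skip: concat(head, tail) -> 3x3 conv, then the 3x3 output conv
    trainable += conv_params(head + 4 * channel_rate, channel_rate, 3)
    trainable += conv_params(channel_rate, 3, 3)
    return trainable, frozen


print(generator_param_count())  # (3197827, 6912)
```

Under these assumptions the totals are 3,197,827 trainable and 6,912 non-trainable parameters, in line with the approximate figures in the notes.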
4. Discriminator network
```python
from keras.layers import (Average, BatchNormalization, Convolution2D, Dense,
                          Flatten, Input, Lambda, LeakyReLU)
from keras.models import Model

channel_rate = 64
image_shape = (256, 256, 3)  # full input image
patch_shape = (64, 64, 3)    # one PatchGAN patch: 4 x 4 = 16 patches per image


def discriminator_model():
    # PatchGAN: scores a single 64x64 patch
    inputs = Input(shape=patch_shape)
    x = inputs
    for filters in (channel_rate, 2 * channel_rate, 4 * channel_rate, 4 * channel_rate):
        x = Convolution2D(filters=filters, kernel_size=(3, 3), strides=(2, 2),
                          padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)
    x = Flatten()(x)
    outputs = Dense(units=1, activation='sigmoid')(x)
    patch_gan = Model(inputs=inputs, outputs=outputs, name='PatchGAN')

    # Discriminator: cut the full image into non-overlapping patches, score
    # each with the shared PatchGAN, and average the scores
    inputs = Input(shape=image_shape)
    # Patch boundaries come from patch_shape (the original used channel_rate,
    # which only worked because both equal 64)
    list_row_idx = [(i * patch_shape[0], (i + 1) * patch_shape[0])
                    for i in range(image_shape[0] // patch_shape[0])]
    list_col_idx = [(i * patch_shape[1], (i + 1) * patch_shape[1])
                    for i in range(image_shape[1] // patch_shape[1])]
    list_patch = []
    for row_idx in list_row_idx:
        for col_idx in list_col_idx:
            # Bind the indices as default arguments; a plain closure would only
            # see the last (row_idx, col_idx) if Keras re-runs the lambda
            x_patch = Lambda(lambda z, r=row_idx, c=col_idx:
                             z[:, r[0]:r[1], c[0]:c[1], :])(inputs)
            list_patch.append(x_patch)
    x = [patch_gan(patch) for patch in list_patch]
    outputs = Average()(x)
    return Model(inputs=inputs, outputs=outputs, name='Discriminator')
```
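The patch slicing and averaging in `discriminator_model` can be mirrored in plain NumPy to see what the 16 Lambda layers and the `Average` layer compute. This is an illustrative sketch: `extract_patches`, `patch_average_score`, and the stand-in scorer `mean_intensity` are hypothetical and stand in for the trained PatchGAN.

```python
import numpy as np


def extract_patches(images, patch_h=64, patch_w=64):
    """Cut an (N, H, W, C) batch into non-overlapping patches, row-major order.

    Returns a list of (N, patch_h, patch_w, C) arrays, like the Lambda slices.
    """
    _, h, w, _ = images.shape
    patches = []
    for r in range(h // patch_h):
        for col in range(w // patch_w):
            patches.append(images[:, r * patch_h:(r + 1) * patch_h,
                                  col * patch_w:(col + 1) * patch_w, :])
    return patches


def patch_average_score(images, score_patch, patch_h=64, patch_w=64):
    # Score every patch with the same function and average, like Average()
    scores = [score_patch(p) for p in extract_patches(images, patch_h, patch_w)]
    return np.mean(scores, axis=0)


def mean_intensity(p):
    # Hypothetical stand-in for the PatchGAN: one score per image, shape (N, 1)
    return p.mean(axis=(1, 2, 3))[:, None]


batch = np.random.default_rng(0).random((2, 256, 256, 3), dtype=np.float32)
patches = extract_patches(batch)
print(len(patches), patches[0].shape)  # 16 (2, 64, 64, 3)
print(patch_average_score(batch, mean_intensity).shape)  # (2, 1)
```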
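* 1. The discriminator has about 960,000 trainable parameters and about 1,400 non-trainable parameters.
* 2. The input is a 256×256 color image.
* 3. The second layer of the discriminator consists of 16 parallel Lambda (cropping) layers, each of which cuts a 64×64 patch out of the image. Each of these 16 patches is then fed into the same PatchGAN module.
* 4. Finally, the 16 PatchGAN outputs are averaged to give the final score, in the range [0, 1], of whether the image is real.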
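Because all 16 patches go through one shared PatchGAN, and the Lambda and Average layers have no weights, the discriminator's parameter count is simply that of a single PatchGAN. The sketch below checks this by hand, assuming channel_rate = 64, 64×64×3 patches, stride-2 'same' convolutions (which halve height and width), and BatchNormalization with two trainable and two non-trainable vectors. The helper names are hypothetical.

```python
def conv_params(c_in, c_out, k=3):
    # k x k kernel weights plus one bias per output channel
    return k * k * c_in * c_out + c_out


def patchgan_param_count(channel_rate=64, patch_hw=64, in_channels=3):
    trainable, frozen = 0, 0
    c_in, size = in_channels, patch_hw
    for c_out in (channel_rate, 2 * channel_rate, 4 * channel_rate, 4 * channel_rate):
        trainable += conv_params(c_in, c_out)  # 3x3 conv, stride 2
        trainable += 2 * c_out                 # BN gamma and beta
        frozen += 2 * c_out                    # BN moving mean and variance
        c_in, size = c_out, (size + 1) // 2    # 'same' padding, stride 2
    trainable += size * size * c_in + 1        # Flatten -> Dense(1)
    return trainable, frozen


print(patchgan_param_count())  # (966401, 1408)
```

The result, 966,401 trainable and 1,408 non-trainable parameters, matches the rounded figures above.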
5. The overall GAN formed by the generator and the discriminator
6. Loss function of the generator G
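```python
from functools import lru_cache

from keras import backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

image_shape = (256, 256, 3)


def l1_loss(y_true, y_pred):
    return K.mean(K.abs(y_pred - y_true))


@lru_cache(maxsize=None)
def get_loss_model():
    # Build the frozen VGG16 (block3_conv3) feature extractor once, not on every call
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False  # the loss model must not be trained
    return loss_model


def perceptual_loss(y_true, y_pred):
    loss_model = get_loss_model()
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))


def generator_loss(y_true, y_pred, K_1=145, K_2=170):
    return K_1 * perceptual_loss(y_true, y_pred) + K_2 * l1_loss(y_true, y_pred)
```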
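The generator loss is a weighted sum, K_1 · (perceptual loss) + K_2 · (L1 loss), with K_1 = 145 and K_2 = 170. The NumPy sketch below mirrors that weighting so the arithmetic can be checked by hand. The `features` argument is a hypothetical stand-in for the frozen VGG16 `block3_conv3` extractor; an identity function is used purely for illustration.

```python
import numpy as np


def l1_loss_np(y_true, y_pred):
    return np.mean(np.abs(y_pred - y_true))


def generator_loss_np(y_true, y_pred, features, K_1=145, K_2=170):
    # features stands in for the VGG16 feature extractor (an assumption)
    perceptual = np.mean(np.square(features(y_true) - features(y_pred)))
    return K_1 * perceptual + K_2 * l1_loss_np(y_true, y_pred)


def identity(t):
    # Hypothetical feature extractor, for illustration only
    return t


y_true = np.zeros((1, 4, 4, 3))
y_pred = np.full((1, 4, 4, 3), 0.5)
print(generator_loss_np(y_true, y_pred, identity))  # 145 * 0.25 + 170 * 0.5 = 121.25
```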
7. Loss function of the discriminator D
The discriminator D uses the logarithmic loss (log loss).
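Written out in its standard binary cross-entropy form, with label y_i = 1 for real sharp images and y_i = 0 for generated ones, the log loss over a batch of N images is shown below. In Keras this corresponds to `loss='binary_crossentropy'`. That is an assumption here, because the compile call is not shown.

$$
\mathcal{L}_D = -\frac{1}{N}\sum_{i=1}^{N}\Big[y_i \log D(x_i) + (1 - y_i)\log\big(1 - D(x_i)\big)\Big]
$$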
8. Loss function of the whole GAN
```python
def adversarial_loss(y_true, y_pred):
    return -K.log(K.clip(y_pred, K.epsilon(), 1.0))
```
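The adversarial loss −log D(G(x)) equals binary cross-entropy against a "real" label of 1. It is small when the discriminator scores the generated image close to 1 and grows as the score drops. The NumPy sketch below checks this equivalence. The function names are hypothetical, and the clipping constant mirrors Keras's default `K.epsilon()` of 1e-7.

```python
import numpy as np

EPS = 1e-7  # same value as Keras's default K.epsilon()


def binary_crossentropy(y_true, y_pred):
    p = np.clip(y_pred, EPS, 1 - EPS)
    return -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))


def adversarial_loss_np(y_pred):
    # -log D(G(x)), clipped to avoid log(0)
    return -np.log(np.clip(y_pred, EPS, 1.0))


d_out = np.array([0.1, 0.5, 0.9])  # D's scores for three generated images
print(adversarial_loss_np(d_out))
print(np.allclose(adversarial_loss_np(d_out), binary_crossentropy(np.ones(3), d_out)))  # True
```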
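9. Training procedure
* 1. Fetch a batch of batch_size sharp and blurred training images.
* 2. Feed the blurred images through the generator to produce the restored (sharp) images.
* 3. Compute the discriminator loss on the real sharp images and the generated images, then freeze the discriminator so that its parameters cannot be updated.
* 4. Compute the loss of the whole GAN and the generator loss.
* 5. Set the discriminator's trainable=True so that it can be trained again.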
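The five steps above can be sketched as a single training iteration. This is a hedged sketch, not the author's code. It assumes Keras-style `generator` and `discriminator` models, plus a combined model named `gan` (a hypothetical name) that maps blurred images to `[generated image, discriminator score]` and was compiled with the generator loss and the adversarial loss. In Keras, a change to `trainable` typically only takes effect when a model is (re)compiled, so the combined model is usually compiled while the discriminator is frozen.

```python
import numpy as np


def train_step(generator, discriminator, gan, x_sharp, x_blur):
    """One iteration of steps 1-5; x_sharp and x_blur form one batch (step 1)."""
    batch_size = x_sharp.shape[0]
    # 2. Blurred images -> restored images
    x_fake = generator.predict(x_blur, verbose=0)
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))
    # 3. Discriminator log loss on real (label 1) and generated (label 0) images
    d_loss_real = discriminator.train_on_batch(x_sharp, real_labels)
    d_loss_fake = discriminator.train_on_batch(x_fake, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    # 3. Freeze the discriminator
    discriminator.trainable = False
    # 4. GAN loss: targets are the sharp images and "real" labels for D's output
    g_loss = gan.train_on_batch(x_blur, [x_sharp, real_labels])
    # 5. Make the discriminator trainable again
    discriminator.trainable = True
    return d_loss, g_loss
```

What `g_loss` contains (for example, the total plus each component loss) depends on how `gan` was compiled.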
10. Training results (illustration)
