Author: 静静敲代码 | Source: Internet | 2023-10-11 21:10
DSLR-Quality Photos on Mobile Devices with Deep Convolutional Networks---color loss, PyTorch implementation
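1. How it works

I have been working on image enhancement recently and came across this paper, in which the authors propose a loss called the color loss. As the paper describes it, the method blurs away the texture and content of both the input image and the ground truth so that only the color information is left, and uses that to correct the image's colors. The implementation is fairly simple: first build a Gaussian blur kernel, then use it as the convolution kernel to blur each image, and finally take the MSE between the blurred input image and the blurred ground truth as the loss.

The authors' GitHub repository includes code for the model, but it is written in TensorFlow. My own code is in PyTorch, so I rewrote this part myself. The authors' code uses a depthwise separable convolution; in my PyTorch version I did not make the convolution depthwise separable.

I am still a beginner at deep learning, so if you spot a problem, feel free to leave me a comment.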
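The steps described above (blur both images, then take the MSE) can be sketched end to end as follows. This is only a minimal illustration under stated assumptions, not the paper authors' code: the names `gaussian_kernel_2d`, `blur` and `color_loss` are made up for this example, and the kernel is a plain sampled Gaussian with an assumed 21x21 size and a standard deviation of 3 pixels. Each channel is blurred as if it were its own single-channel image, by folding the channel axis into the batch axis.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_2d(size: int = 21, sigma: float = 3.0) -> torch.Tensor:
    """Return a normalized size x size Gaussian kernel (entries sum to 1)."""
    # Assumption: a sampled Gaussian centered on the middle tap.
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2.0
    g = torch.exp(-coords ** 2 / (2.0 * sigma ** 2))
    kernel = torch.outer(g, g)
    return kernel / kernel.sum()


def blur(img: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Blur an (N, C, H, W) batch, treating every channel as its own image."""
    n, c, h, w = img.shape
    weight = kernel[None, None]           # (1, 1, k, k)
    pad = kernel.shape[-1] // 2           # keeps H and W unchanged for odd k
    flat = img.reshape(n * c, 1, h, w)    # fold channels into the batch axis
    return F.conv2d(flat, weight, padding=pad).reshape(n, c, h, w)


def color_loss(pred: torch.Tensor, target: torch.Tensor,
               size: int = 21, sigma: float = 3.0) -> torch.Tensor:
    """MSE between the blurred prediction and the blurred ground truth."""
    kernel = gaussian_kernel_2d(size, sigma).to(device=pred.device, dtype=pred.dtype)
    return F.mse_loss(blur(pred, kernel), blur(target, kernel))


# Example call on random data (hypothetical shapes).
pred = torch.rand(2, 3, 64, 64)
target = torch.rand(2, 3, 64, 64)
loss = color_loss(pred, target)
```

Because the blur removes fine detail, a uniform color shift changes this loss far more than pixel-level texture of the same amplitude does, which is the behavior the paragraph above attributes to the color loss.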
2. Code

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import math
    from math import exp, pi
    import numpy as np
    import cv2 as cv
    import scipy.stats as st
    import matplotlib.pyplot as plt


    def gauss_kernel(kernlen=21, nsig=3, channels=1):
        # Build a normalized kernlen x kernlen Gaussian kernel spanning
        # roughly +/- nsig standard deviations.
        interval = (2 * nsig + 1.) / (kernlen)
        x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
        kern1d = np.diff(st.norm.cdf(x))
        kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
        kernel = kernel_raw / kernel_raw.sum()
        out_filter = np.array(kernel, dtype=np.float32)
        out_filter = out_filter.reshape((kernlen, kernlen))
        # out_filter = np.repeat(out_filter, channels, axis=0)
        return out_filter


    # kernel_size = 21
    class SeparableConv2d(nn.Module):
        def __init__(self):
            super(SeparableConv2d, self).__init__()
            kernel = gauss_kernel(21, 3, 3)
            # Shape [21, 21] -> [1, 1, 21, 21] (out_channels, in_channels, kH, kW).
            kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
            ## kernel_point = [[1.0]]
            ## kernel_point = torch.FloatTensor(kernel_point).unsqueeze(0).unsqueeze(0)
            # torch.expand() adds dimensions in front of the input's dimensions;
            # for three-channel input, expand the weight to [3, 3, 21, 21]:
            # kernel = torch.FloatTensor(kernel).expand(3, 3, 21, 21)
            ## kernel_point = torch.FloatTensor(kernel_point).expand(3, 3, 1, 1)
            # Fixed (non-trainable) blur weight.
            self.weight = nn.Parameter(data=kernel, requires_grad=False)
            # Single channel: in_channels=1, out_channels=1; three channels:
            # in_channels=3, out_channels=3. This kernel is randomly initialized.
            # self.pointwise = nn.Conv2d(1, 1, 1, 1, 0, 1, 1, bias=False)
            ## self.weight_point = nn.Parameter(data=kernel_point, requires_grad=False)

        def forward(self, img1):
            # padding=10 keeps the spatial size unchanged for the 21x21 kernel.
            x = F.conv2d(img1, self.weight, groups=1, padding=10)
            ## x = F.conv2d(x, self.weight_point, groups=1, padding=0)  # kernel is [1]
            # x = self.pointwise(x)
            return x

    # plt.imshow(out_kernel)
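As written, `SeparableConv2d` holds a single `[1, 1, 21, 21]` weight and calls `F.conv2d` with `groups=1`, so it only accepts single-channel input. The commented-out `[3, 3, 21, 21]` expansion would accept RGB, but with `groups=1` every output channel would be the sum of all three blurred input channels. One possible way to blur RGB channels independently is a depthwise convolution, i.e. a `[C, 1, k, k]` weight with `groups=C`. The sketch below is an assumption-laden illustration rather than part of the original post: `DepthwiseBlur` is a hypothetical name, and a box filter stands in for the Gaussian kernel so the block stays self-contained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthwiseBlur(nn.Module):
    """Blur each channel independently with one fixed 2-D kernel."""

    def __init__(self, kernel: torch.Tensor, channels: int = 3):
        super().__init__()
        k = kernel.shape[-1]
        # Depthwise weight layout: (out_channels, in_channels / groups, kH, kW).
        weight = kernel.reshape(1, 1, k, k).repeat(channels, 1, 1, 1)
        # A buffer is not trained but still follows the module across devices.
        self.register_buffer("weight", weight)
        self.channels = channels
        self.padding = k // 2  # keeps H and W unchanged for odd k

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return F.conv2d(img, self.weight, padding=self.padding,
                        groups=self.channels)


# Stand-in kernel for illustration; the post's kernel could be passed as
# torch.from_numpy(gauss_kernel(21, 3)) instead.
box = torch.full((21, 21), 1.0 / (21 * 21))
blur_rgb = DepthwiseBlur(box, channels=3)
out = blur_rgb(torch.rand(1, 3, 64, 64))  # shape stays (1, 3, 64, 64)
```

Registering the kernel as a buffer is a design choice made here instead of `nn.Parameter(..., requires_grad=False)`; both keep the blur fixed during training, but a buffer does not show up in `parameters()` and so is never handed to the optimizer.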