200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > 利用perceptual_loss感知损失获得更好的图片重建效果

利用perceptual_loss感知损失获得更好的图片重建效果

时间:2024-03-24 09:41:36

相关推荐

利用perceptual_loss感知损失获得更好的图片重建效果

利用perceptual_loss感知损失获得更好的图片重建效果

传统的MSEloss在图像重建领域会带来图像高频信息缺失的问题,导致生成的图片出现模糊。感知损失通过对卷积提取的高层信息进行比较,很好的缓解了上述问题,在此提供一个独立的perceptual_loss代码,方便初学者在训练过程中使用

def build_net(ntype,nin,nwb=None,name=None):if ntype=='conv':return tf.nn.relu(tf.nn.conv2d(nin,nwb[0],strides=[1,1,1,1],padding='SAME',name=name)+nwb[1])elif ntype=='pool':return tf.nn.avg_pool(nin,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')def get_weight_bias(vgg_layers,i):weights=vgg_layers[i][0][0][2][0][0]weights=tf.constant(weights)bias=vgg_layers[i][0][0][2][0][1]bias=tf.constant(np.reshape(bias,(bias.size)))return weights,biasvgg_path=scipy.io.loadmat('your vgg19path')print("[i] Loaded pre-trained vgg19 parameters")# build VGG19 to load pre-trained parametersdef build_vgg19(input,reuse=False):with tf.variable_scope("vgg19"):if reuse:tf.get_variable_scope().reuse_variables()net={}vgg_layers=vgg_path['layers'][0]net['input']=inputnet['conv1_1']=build_net('conv',net['input'],get_weight_bias(vgg_layers,0),name='vgg_conv1_1')net['conv1_2']=build_net('conv',net['conv1_1'],get_weight_bias(vgg_layers,2),name='vgg_conv1_2')net['pool1']=build_net('pool',net['conv1_2'])net['conv2_1']=build_net('conv',net['pool1'],get_weight_bias(vgg_layers,5),name='vgg_conv2_1')net['conv2_2']=build_net('conv',net['conv2_1'],get_weight_bias(vgg_layers,7),name='vgg_conv2_2')net['pool2']=build_net('pool',net['conv2_2'])net['conv3_1']=build_net('conv',net['pool2'],get_weight_bias(vgg_layers,10),name='vgg_conv3_1')net['conv3_2']=build_net('conv',net['conv3_1'],get_weight_bias(vgg_layers,12),name='vgg_conv3_2')net['conv3_3']=build_net('conv',net['conv3_2'],get_weight_bias(vgg_layers,14),name='vgg_conv3_3')net['conv3_4']=build_net('conv',net['conv3_3'],get_weight_bias(vgg_layers,16),name='vgg_conv3_4')net['pool3']=build_net('pool',net['conv3_4'])net['conv4_1']=build_net('conv',net['pool3'],get_weight_bias(vgg_layers,19),name='vgg_conv4_1')net['conv4_2']=build_net('conv',net['conv4_1'],get_weight_bias(vgg_layers,21),name='vgg_conv4_2')net['conv4_3']=build_net('conv',net['conv4_2'],get_weight_bias(vgg_layers,23),name='vgg_conv4_3')net['conv4_4']=build_net('conv',net['conv4_3'],get_weight_bias(vgg_layers,25),name='vgg_conv4_4')net['pool4']=build_net('pool',net['conv4_4'])net['conv5_1']=build_net('conv',net['pool4'],get_weight_bias(vgg_layers,28),name='vgg_conv5_1')net['conv5_2']=build_net('conv',net['conv5_1'],get_weight_bias(vgg_layers,30),name='vgg_conv5_2')return netdef compute_l1_loss(input, output):return tf.reduce_mean(tf.abs(input-output))def compute_percep_loss(input, output, reuse=False):vgg_real=build_vgg19(output,reuse=reuse)vgg_fake=build_vgg19(input,reuse=True)p0=compute_l1_loss(vgg_real['input'],vgg_fake['input'])p1=compute_l1_loss(vgg_real['conv1_2'],vgg_fake['conv1_2'])/2.6p2=compute_l1_loss(vgg_real['conv2_2'],vgg_fake['conv2_2'])/4.8p3=compute_l1_loss(vgg_real['conv3_2'],vgg_fake['conv3_2'])/3.7p4=compute_l1_loss(vgg_real['conv4_2'],vgg_fake['conv4_2'])/5.6p5=compute_l1_loss(vgg_real['conv5_2'],vgg_fake['conv5_2'])*10/1.5return p0+p1+p2+p3+p4+p5

将上述代码加入你的模型中,并在loss中调用compute_percep_loss即可,VGG模型可以在此处下载: VGG19.

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。