热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

记录训练神经网络过程中常用的权值共享和不同层赋予不同学习率等方法

权值共享 import torchimport torch.nn as nnclass model(nn.Module):def __init__(self):super(model,self).__

权值共享

import torch
import torch.nn as nnclass model(nn.Module):def __init__(self):super(model,self).__init__()self.lstm = nn.LSTM(input_size = 10,hidden_size = 5)self.linear = nn.Linear(input_features = 5,out_features = 2)def forward(self,inputdata1,inputdata2):lstm_result1 = self.lstm (input_data1)lstm_result2 = self.lstm(inputdata2)output = self.linear(lstm_result1+lstm_result2)return output

注释:在神经网络的训练过程中经常用到两层网络共享权值,在上述代码片中,定义神经网络时定义一个lstm模型和一个全连接层,在前向计算中多次调用lstm层进行计算,相当于神经网络模型中有两个lstm层,即计算inputdata1和inputdata2的两个lstm共享权值。
参考:https://www.cnblogs.com/sdu20112013/p/12132786.html

某些层参数不更新

在查询此类资料时,在博客中看到模型层中添加了requires_grad = False后参数仍会训练的问题,博主并给出了相关解决方法,这里记录两个感觉使用方便的方法。
更多内容参考:https://blog.csdn.net/guotong1988/article/details/79739775

import torch
import torch.nn as nnclass model(nn.Module):def __init__(self):super(model,self).__init__()self.lstm = nn.LSTM(input_size = 10,hidden_size = 5)for p in self.parameters():p.requires_grad = Falseself.linear = nn.Linear(input_features = 5,out_features = 2)def forward(self,inputdata1,inputdata2):lstm_result1 = self.lstm (input_data1)lstm_result2 = self.lstm(inputdata2)output = self.linear(lstm_result1+lstm_result2)return output

注释:在不需要参数更新的层后边添加如下代码行:

for p in self.parameters():p.requires_grad = False

但是上述方法适用于模型中最初几层都不需训练,顶层需要训练的情况,如果出现需要训练和不需要训练的模型层交替出现的时候,上述方法就无法使用。博主给出了使用范围更广的方法:

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLossclass model(nn.Module):def __init__(self):super(model,self).__init__()self.lstm1 = nn.LSTM(input_size = 10,hidden_size = 10,requires_grad = True)self.lstm2 = nn.LSTM(input_size = 10,hidden_size = 5,requires_grad = False)self.linear = nn.Linear(input_features = 5,out_features = 2,requires_grad = True)def forward(self,inputdata):lstm_result1 = self.lstm1(input_data)lstm_result2 = self.lstm(lstm_result1)output = self.linear(lstm_result1+lstm_result2)return output
model = model()
#人为构造输入和真实标签
input_data = torch.randn([1,10])#[1,10]代表输入一个样本,该样本的向量是10维,此处必须是二位数据
target = torch.tensor([1],dtype = torch.long)#输入一个样本时真实标签只有一个,如果输入是[5,10],则真实标签就应该为5,例如,torch.tensor([0,1,1,1,0])#模型计算,反向传播
result = model(input_data)
loss_fc = CrossEntropyLoss()
loss = loss_fc(input_data,target)
loss.backward()#优化函数优化
torch.optimizer.SGD(filter(lambda p:p.requires_grad = True,model.parameters(),lr = 0.01))

注释:上述代码片在优化函数部分对参数进行过滤,只选取requires_grad = True的参数进行优化更新。

为不同的层赋予不同的学习率

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLossclass model(nn.Module):def __init__(self):super(model,self).__init__()self.lstm = nn.LSTM(input_size = 10,hidden_size = 10,requires_grad = True)self.linear = nn.Linear(input_features = 5,out_features = 2,requires_grad = True)def forward(self,inputdata):lstm_result = self.lstm(input_data)output = self.linear(lstm_result)return outputmodel = model()#人为构造输入和真实标签
input_data = torch.randn([1,10])#[1,10]代表输入一个样本,该样本的向量是10维,此处必须是二位数据
target = torch.tensor([1],dtype = torch.long)#输入一个样本时真实标签只有一个,如果输入是[5,10],则真实标签就应该为5,例如,torch.tensor([0,1,1,1,0])#模型计算,反向传播
result = model(input_data)
loss_fc = CrossEntropyLoss()
loss = loss_fc(input_data,target)
loss.backward()#使用优化函数优化过程中,为不同的层赋予不同的学习率,
param_lstm = [p for p in model.lstm.parameters()]
param_linear = [p for p in model.linear.parameters()]
params = [{'params':param_lstm,'lr':0.1},{'params':param_linear,'lr':0.01}]
torch.optimizer.SGD(params)

将两个模型参数的平均值赋予第三个模型

import torch
import torch.nn as nn
from collections import OrderedDict
#创建两个模型
model1 = nn.Linear(10,10)
model2 = nn.Linear(10,10)#获取两个模型的平均值
param_dict = {}
for key in model1.state_dict.keys():#model1.state_dict()输出值为OrderedDict类型param_key = (model1.state_dict[key] + model2.state_dict[key]) / 2param_dict[key] = param_key#将两个模型的平均值转换成OrderedDict类型,并赋予第三个模型
param_dict = OrderedDict(param_dict)
model3 = nn.Linear(10,10)#三个模型的构造必须一致
model3.load_state_dict(param_dict)

输出模型中每个层的梯度

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLossclass model(nn.Module):def __init__(self):super(model,self).__init__()self.lstm = nn.LSTM(input_size = 10,hidden_size = 10,requires_grad = True)self.linear = nn.Linear(input_features = 5,out_features = 2,requires_grad = True)def forward(self,inputdata):lstm_result = self.lstm(input_data)output = self.linear(lstm_result)return outputmodel = model()#人为构造输入和真实标签
input_data = torch.randn([1,10])#[1,10]代表输入一个样本,该样本的向量是10维,此处必须是二位数据
target = torch.tensor([1],dtype = torch.long)#输入一个样本时真实标签只有一个,如果输入是[5,10],则真实标签就应该为5,例如,torch.tensor([0,1,1,1,0])#模型计算,反向传播
result = model(input_data)
loss_fc = CrossEntropyLoss()
loss = loss_fc(input_data,target)
loss.backward()#输出不同层的梯度
print(model.lstm.grad)
print(model.linear.grad)#细分输出不同层权值和偏置的梯度
print(model.lstm.weight.grad)
print(model.lstm.bias.grad)
print(model.linear.weight.grad)
print(model.linear.bias.grad)

查看模型梯度参考:https://zhuanlan.zhihu.com/p/36121066
后续还需了解如何直接为某层赋予一定的梯度。


推荐阅读
  • 如何自行分析定位SAP BSP错误
    The“BSPtag”Imentionedintheblogtitlemeansforexamplethetagchtmlb:configCelleratorbelowwhichi ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • Linux服务器密码过期策略、登录次数限制、私钥登录等配置方法
    本文介绍了在Linux服务器上进行密码过期策略、登录次数限制、私钥登录等配置的方法。通过修改配置文件中的参数,可以设置密码的有效期、最小间隔时间、最小长度,并在密码过期前进行提示。同时还介绍了如何进行公钥登录和修改默认账户用户名的操作。详细步骤和注意事项可参考本文内容。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • importjava.util.ArrayList;publicclassPageIndex{privateintpageSize;每页要显示的行privateintpageNum ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
author-avatar
匿名用户
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有