热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch常见预训练模型的下载链接及使用指南

本文提供了PyTorch框架中常用的预训练模型的下载链接及详细使用指南,涵盖ResNet、Inception、DenseNet、AlexNet、VGGNet等六大分类模型。每种模型的预训练参数均经过精心调优,适用于多种计算机视觉任务。文章不仅介绍了模型的下载方式,还详细说明了如何在实际项目中高效地加载和使用这些模型,为开发者提供全面的技术支持。

pytorch框架:常用模型的预训练参数

六大分类模型下载方式和使用方法:
Resnet
inception
Densenet
Alexnet
vggnet

Resnet:
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
inception:
model_urls = {
# Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}
Densenet:
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
Alexnet:
model_urls = {
'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
vggnet:
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}

学习内容:测试实现预训练模型的使用,并牢记该方式-拿为己用

关键步骤讲述:



  1. 默认已经安装好环境和pytorch框架,以及torchvision等需要的库。



  2. import torchvision.models as models 所有成熟网络模型几乎都在里面



  3. # 初始化模型 model = models.resnet18()此处应用ResNet18来分类。



  4. 修改尾巴,毕竟你的输出不一定和原版(1000)一模一样。
    # 修改网络结构,将fc层1000个输出改为9个输出。
    # 获取最后一层的输入特征层信息。 fc_input_feature = model.fc.in_features
    # 取代原来输出层为新的nn。 model.fc = nn.Linear(fc_input_feature, 9)到这里,网络就构建好了。



  5. 下载预训练参数,为己所用。# load除最后一层的预训练权重 pretrained_weight = torch.hub.load_state_dict_from_url( url='https://download.pytorch.org/models/resnet18-5c106cde.pth', progress=True)到这里,下载的是原版的1000分类的参数,我们需要删除不需要的尾巴,并训练自己的尾巴。del pretrained_weight['fc.weight']
    del pretrained_weight['fc.bias']因为分类就是用的线性函数,包括权重w和偏移b,只需删除尾巴。



  6. 最后,将剩下的模型参数load到我们的模型上即可。model.load_state_dict(pretrained_weight, strict=False)模型准备完毕,剩下的操作和所有训练方法一样。参见详细训练代码。



import os
import torch
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
import torchvision.models as models
import time
# use res18
# from resnet.resnetmini import ClassificModel as Model
from datasets.read_data_sleep import PlayPhoneData
def train(data_path=r"E:\Datasets\sleep_traindata"):
# 设置超参数
batch_size = 1 # 每次训练的数据量
LR = 0.01 # 学习率
STEP_SIZE = 5 # 控制学习率变化
MAX_EPOCH = 20 # 总的训练次数
num_print = 100 # 每n个batch打印一次
playPhoneData = PlayPhoneData(data_path)
# 利用dataloader加载数据集
train_loader = torch.utils.data.DataLoader(playPhoneData, batch_size=batch_size, shuffle=True, drop_last=True)
# 生成驱动器
use_gpu = torch.cuda.is_available()
if use_gpu:
print('congratulation! You can use gpu to support acceleration')
else:
print('oppps, please use a small batch size')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 初始化模型
model = models.resnet18()
# 修改网络结构,将fc层1000个输出改为9个输出
fc_input_feature = model.fc.in_features
model.fc = nn.Linear(fc_input_feature, 9)
# load除最后一层的预训练权重
pretrained_weight = torch.hub.load_state_dict_from_url(
url='https://download.pytorch.org/models/resnet18-5c106cde.pth', progress=True)
del pretrained_weight['fc.weight']
del pretrained_weight['fc.bias']
model.load_state_dict(pretrained_weight, strict=False)
model.to(device)
# net = Model(8).to(device) # class_num=8分八类:睡岗(趴着睡,躺着睡,仰着睡,低头睡),玩手机(俯视玩手机,平视玩手机,侧视玩手机),其他=[0,1,2,3,4,5,6,7]
# net = Model(9).to(device) # class_num=9分九类:睡岗(趴着睡,躺着睡,低头睡),站立,半蹲,坐(背坐,正坐,侧坐),其他=[0,1,2,3,4,5,6,7,8]
# 损失函数
get_loss = nn.CrossEntropyLoss() #交叉熵损失函数
# SGD优化器 第一个参数是输入需要优化的参数,第二个是学习率,第三个是动量,大致就是借助上一次导数结果,加快收敛速度。
'''
这一行代码里面实际上包含了多种优化:
一个是动量优化,增加了一个关于上一次迭代得到的系数的偏置,借助上一次的指导,减小梯度震荡,加快收敛速度
一个是权重衰减,通过对权重增加一个(正则项),该正则项会使得迭代公式中的权重按照比例缩减,这么做的原因是,过拟合的表现一般为参数浮动大,使用小参数可以防止过拟合
'''
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=0.001)
# optimizer = optim.Adam(net.parameters(), lr=learn_rate)
# 动态调整学习率 StepLR 是等间隔调整学习率,每step_size 令lr=lr*gamma
# 学习率衰减,随着训练的加深,目前的权重也越来越接近最优权重,原本的学习率会使得,loss上下震荡,逐步减小学习率能加快收敛速度。
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=0.5, last_epoch=-1)
# Step:设置学习率下降策略
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
loss_list = []
start = time.time()
for epoch in range(MAX_EPOCH):
running_loss = 0.0
# enumerate()是python自带的函数,用于迭代字典。参数1,是需要迭代的对象,第二参数是迭代的起始位置
for i, (inputs, labels) in enumerate(train_loader, 0):
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs) # 前向传播求出预测的值
optimizer.zero_grad() # 将梯度初始化为0
loss = get_loss(outputs, labels.long())
loss.backward() # 反向传播求梯度
optimizer.step() # 更新所有参数
running_loss += loss.item() # loss是张量,访问值时需要使用item()
loss_list.append(loss.item())
if i % num_print == num_print - 1: # 每num_print打印平均loss
print('[%d epoch, %d] loss: %.6f' % (epoch + 1, i + 1, running_loss / num_print))
running_loss = 0.0
lr = optimizer.param_groups[0]['lr'] # 查看目前的学习率
print('learn_rate : %.5f' % lr)
scheduler.step() # 根据迭代epoch更新学习率
end = time.time()
print('time:{}'.format(end - start))
torch.save(model, f'E:/model/playphone+sleepthepose/model_resnetmini_睡岗9分类{end}.pth')
if __name__ == "__main__":
train()

训练情况:

......
[3 epoch, 500] loss: 2.186424
[3 epoch, 600] loss: 2.192622
[3 epoch, 700] loss: 2.165229
[3 epoch, 800] loss: 2.125184
[3 epoch, 900] loss: 2.185377
learn_rate : 0.01000
[4 epoch, 100] loss: 2.138786
[4 epoch, 200] loss: 2.177925
[4 epoch, 300] loss: 2.103718
......

备注:代码只是讲解工具,并非可以运行的实例,因为里面的数据集需要有并自己写数据集的代码。




学习内容:进阶应用方法

直接拿来用固然不错,但自己分装一遍再用,显得更加标准,有水平。
比如封装如下:


class ResNet18forClassify(nn.Module):
def __init__(self, phase="train"):
super(ResNet18forClassify, self).__init__()
self.phase = phase
self.net = models.resnet18()
fc_input_feature = self.net.fc.in_features
self.net.fc = nn.Linear(fc_input_feature, 9)
pretrained_weight = torch.hub.load_state_dict_from_url(
url='https://download.pytorch.org/models/resnet18-5c106cde.pth', progress=True)
del pretrained_weight['fc.weight']
del pretrained_weight['fc.bias']
self.net.load_state_dict(pretrained_weight, strict=False)
self.softmax = nn.Softmax(dim=1)
def forward(self, input_img):
out = self.net(input_img)
if self.phase == "test":
return self.softmax(out)
return out

备注:封装成自己的网络模型,更加方便。
其中,if self.phase == "test": return self.softmax(out),分类时训练输出的是类别标签与实际标签做损失计算;测试时,预测结果由激活函数转换为–类型和该类型可能性概率。输出可能是该类别的概率值。


参考文献:

1.https://github.com/pytorch/vision/tree/master/torchvision/models
2.环境搭建:NVIDIA+CUDA+cudaNN的配置与Anaconda虚拟环境的搭建–深度学习第一步
3.Parallax:常用预训练模型下载地址



来源:柏常青



推荐阅读
  • 本文探讨了在Git子模块目录中运行pre-commit时遇到的错误,并提供了一种通过Docker环境解决此问题的方法。 ... [详细]
  • 本文介绍如何利用Python中的pyftpdlib库快速搭建一个功能完备的FTP服务器。此示例代码采用基础配置,适合初学者理解FTP服务器的工作机制,包括用户权限管理、连接限制及被动端口设置等。 ... [详细]
  • 本文详细介绍了 TensorFlow 的入门实践,特别是使用 MNIST 数据集进行数字识别的项目。文章首先解析了项目文件结构,并解释了各部分的作用,随后逐步讲解了如何通过 TensorFlow 实现基本的神经网络模型。 ... [详细]
  • 本文介绍了一种算法,用于从一个整数的末尾获取第 K 位数字。如果该位置不存在,则返回 -1。 ... [详细]
  • 吴裕雄探讨混合神经网络模型在深度学习中的应用:结合RNN与CNN优化网络性能
    本文由吴裕雄撰写,深入探讨了如何利用Python、Keras及TensorFlow构建混合神经网络模型,特别是通过结合递归神经网络(RNN)和卷积神经网络(CNN),实现对网络运行效率的有效提升。 ... [详细]
  • Alluxio 1.5.0 版本发布:增强功能与优化
    Alluxio 1.5.0 开源版本引入了多项新特性和改进,旨在提升数据访问速度和系统互操作性。 ... [详细]
  • 基于鲁棒弹性变形的视差容忍图像拼接 - Python实现
    本文介绍了《视差容忍图像拼接基于鲁棒弹性变形》(Parallax-Tolerant Image Stitching Based on Robust Elastic Warping)论文的Python实现,主要针对两幅图像的拼接。该方法通过弹性变形技术提高图像拼接的质量,尤其是在存在视差的情况下。 ... [详细]
  • 本文详细介绍如何通过Anaconda 3.5.01快速安装TensorFlow,包括环境配置和具体步骤。 ... [详细]
  • 本文通过Python编程语言,利用Pandas和Matplotlib库,对电影数据集中的类型字段进行处理,实现电影类型的统计分析及可视化展示。 ... [详细]
  • 本文档详细介绍了如何使用XIB文件创建和管理具有不同高度的单元格,通过具体的代码示例展示了在iOS开发中实现这一功能的方法。 ... [详细]
  • Spring Cloud实践:构建Eureka单节点注册中心
    本文详细介绍如何在Spring Cloud环境下搭建Eureka单节点注册中心,包括项目初始化、依赖添加、配置设置及启动测试等步骤。 ... [详细]
  • MQTT协议:轻量级消息传输的基石
    MQTT(Message Queuing Telemetry Transport,消息队列遥测传输)是一种基于发布/订阅模式的轻量级通信协议,适用于低带宽、高延迟或不可靠的网络环境。该协议基于TCP/IP构建,由IBM在1999年首次推出,旨在通过最小化网络流量和代码量,为远程设备提供高效、可靠的消息传输服务。 ... [详细]
  • Python标准库概览:shelve模块的使用
    当项目需要一个简单且高效的存储方案时,Python的shelve模块是一个不错的选择。本文将详细介绍如何利用shelve模块进行基本的数据持久化操作,包括如何打开数据库、进行数据的增删查改等。 ... [详细]
  • 本文探讨了如何在Python中处理长数据的完全显示问题,包括numpy数组、pandas DataFrame以及tensor类型的完整输出设置。 ... [详细]
  • 本文探讨了如何在TensorFlow中使用张量来处理和分析数字图像,特别是通过具体的代码示例展示了张量在图像处理中的作用。 ... [详细]
author-avatar
UUUUUUUUUU8
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有