热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Caffe源码阅读(1)全连接层

Caffe源码阅读(1)全连接层发表于2014-09-15|Caffe源码阅读(1)全连接层发表于2014-09-15|今天看全连接层的实现。主要看的是ht

Caffe源码阅读(1) 全连接层

今天看全连接层的实现。
主要看的是https://github.com/BVLC/caffe/blob/master/src/caffe/layers/inner_product_layer.cpp

主要是三个方法,setup,forward,backward

  • setup 初始化网络参数,包括了w和b
  • forward 前向传播的实现
  • backward 后向传播的实现

setup

主体的思路,作者的注释给的很清晰。
主要是要弄清楚一些变量对应的含义

1
2
3
M_ 表示的样本数
K_ 表示单个样本的特征长度
N_ 表示输出神经元的个数

 

为了打字方便,以下省略下划线,缩写为M,K,N

forward

实现的功能就是 y=wx+b

1
2
3
4
x为输入,维度 MxK
y为输出,维度 Nx1
w为权重,维度 NxK
b为偏置,维度 Nx1

 

具体到代码实现,用的是这个函数caffe_cpu_gemm,具体的函数头为

1
2
3
4
void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
const float alpha, const float* A, const float* B, const float beta,
float* C)

 

略长,整理它的功能其实很直观,即C←αA×B+βC

1
2
3
4
5
6
7
8
9
10
11
12
13
14
const CBLAS_TRANSPOSE TransA  # A是否转置
const CBLAS_TRANSPOSE TransB # B是否转置

# 这部分都比较直观不用解释了
const int M
const int N
const int K
const float alpha
const float* A
const float* B
const float beta,
float* C

# 其中A维度是MxK,B维度是KxN,C维度为MxN

 

从实际代码来算,全连接层的forward包括了两步:

1
2
3
4
5
6
7
8
# 这一步表示 y←wx,或者说是y←xw\'
caffe_cpu_gemmtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
# 这一步表示 y←y+b
caffe_cpu_gemmtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
bias_multiplier_.cpu_data(),
this->blobs_[1]->cpu_data(), (Dtype)1., top_data);
# 所以两步连起来就等价于y=wx+b

 

backward

分成三步:

  • 更新w
  • 更新b
  • 计算delta

用公式来说是下面三条:

一步步来,先来第一步,更新w,对应代码是:

1
2
caffe_cpu_gemm(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());

 

对照公式,有

1
2
3
需要更新的w的梯度的维度是NxK
公式中的a^(l)_j对应的是bottom_data,维度是KxM
公式中的\delta_(l+1)_i对应的是top_diff,维度是NxM

 

然后是第二步,更新b,对应代码是:

1
2
3
caffe_cpu_gemv(CblasTrans, M_, N_, (Dtype)1., top_diff,
bias_multiplier_.cpu_data(), (Dtype)0.,
this->blobs_[1]->mutable_cpu_diff());

 

这里用到了caffe_cpu_gemv,简单来说跟上面的caffe_cpu_gemm类似,不过前者是计算矩阵和向量之间的乘法的(从英文命名可以分辨,v for vector, m for matrix)。函数头:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
void caffe_cpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
const int N, const float alpha, const float* A, const float* x,
const float beta, float* y)

# 实现的功能类似 Y←αAX + βY
# 其中A的维度为 MxN
# X是一个向量,维度为 Nx1
# Y是结果 ,也是一个向量,维度为Mx1

const CBLAS_TRANSPOSE TransA # 是否对A进行转置

# 下面的参数很直观,不描述了
const int M
const int N
const float alpha
const float* A
const float* x
const float beta
float* y

 

绕回到具体的代码实现。。如何更新b?根据公式b的梯度直接就是delta

1
2
3
4
# 所以对应的代码其实就是将top_diff转置后就可以了(忽略乘上bias_multiplier这步)
caffe_cpu_gemvtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
bias_multiplier_.cpu_data(), (Dtype)0.,
this->blobs_[1]->mutable_cpu_diff());

 

第三步是计算delta,对应公式

这里面可以忽略掉最后一项f’,因为在caffe实现中,这是由Relu layer来实现的,这里只需要实现括号里面的累加就好了,这个累加其实可以等价于矩阵乘法

1
2
3
4
5
6
7
caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
(*bottom)[0]->mutable_cpu_diff());

# top_diff为\delta^(l+1)_j 维度 MxN
# this->blobs_[0]->cpu_data()为W^(l)_ji 维度 NxK
# (*bottom)[0]->mutable_cpu_diff()是要计算的结果,也就是\delta^(l)_i 维度是MxK

 

附录

又及,这里具体计算矩阵相乘用的是blas的功能,描述页面我参考的是:https://developer.apple.com/library/mac/documentation/Accelerate/Reference/BLAS_Ref/Reference/reference.html#//apple_ref/c/func/cblas_sgemm


推荐阅读
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • EPICS Archiver Appliance存储waveform记录的尝试及资源需求分析
    本文介绍了EPICS Archiver Appliance存储waveform记录的尝试过程,并分析了其所需的资源容量。通过解决错误提示和调整内存大小,成功存储了波形数据。然后,讨论了储存环逐束团信号的意义,以及通过记录多圈的束团信号进行参数分析的可能性。波形数据的存储需求巨大,每天需要近250G,一年需要90T。然而,储存环逐束团信号具有重要意义,可以揭示出每个束团的纵向振荡频率和模式。 ... [详细]
  • 目录实现效果:实现环境实现方法一:基本思路主要代码JavaScript代码总结方法二主要代码总结方法三基本思路主要代码JavaScriptHTML总结实 ... [详细]
  • 本文介绍了九度OnlineJudge中的1002题目“Grading”的解决方法。该题目要求设计一个公平的评分过程,将每个考题分配给3个独立的专家,如果他们的评分不一致,则需要请一位裁判做出最终决定。文章详细描述了评分规则,并给出了解决该问题的程序。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • 使用在线工具jsonschema2pojo根据json生成java对象
    本文介绍了使用在线工具jsonschema2pojo根据json生成java对象的方法。通过该工具,用户只需将json字符串复制到输入框中,即可自动将其转换成java对象。该工具还能解析列表式的json数据,并将嵌套在内层的对象也解析出来。本文以请求github的api为例,展示了使用该工具的步骤和效果。 ... [详细]
  • 本文介绍了P1651题目的描述和要求,以及计算能搭建的塔的最大高度的方法。通过动态规划和状压技术,将问题转化为求解差值的问题,并定义了相应的状态。最终得出了计算最大高度的解法。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • 解决VS写C#项目导入MySQL数据源报错“You have a usable connection already”问题的正确方法
    本文介绍了在VS写C#项目导入MySQL数据源时出现报错“You have a usable connection already”的问题,并给出了正确的解决方法。详细描述了问题的出现情况和报错信息,并提供了解决该问题的步骤和注意事项。 ... [详细]
  • 闭包一直是Java社区中争论不断的话题,很多语言都支持闭包这个语言特性,闭包定义了一个依赖于外部环境的自由变量的函数,这个函数能够访问外部环境的变量。本文以JavaScript的一个闭包为例,介绍了闭包的定义和特性。 ... [详细]
  • IhaveconfiguredanactionforaremotenotificationwhenitarrivestomyiOsapp.Iwanttwodiff ... [详细]
  • Webpack5内置处理图片资源的配置方法
    本文介绍了在Webpack5中处理图片资源的配置方法。在Webpack4中,我们需要使用file-loader和url-loader来处理图片资源,但是在Webpack5中,这两个Loader的功能已经被内置到Webpack中,我们只需要简单配置即可实现图片资源的处理。本文还介绍了一些常用的配置方法,如匹配不同类型的图片文件、设置输出路径等。通过本文的学习,读者可以快速掌握Webpack5处理图片资源的方法。 ... [详细]
  • 关于我们EMQ是一家全球领先的开源物联网基础设施软件供应商,服务新产业周期的IoT&5G、边缘计算与云计算市场,交付全球领先的开源物联网消息服务器和流处理数据 ... [详细]
  • XML介绍与使用的概述及标签规则
    本文介绍了XML的基本概念和用途,包括XML的可扩展性和标签的自定义特性。同时还详细解释了XML标签的规则,包括标签的尖括号和合法标识符的组成,标签必须成对出现的原则以及特殊标签的使用方法。通过本文的阅读,读者可以对XML的基本知识有一个全面的了解。 ... [详细]
  • 本文介绍了如何使用Express App提供静态文件,同时提到了一些不需要使用的文件,如package.json和/.ssh/known_hosts,并解释了为什么app.get('*')无法捕获所有请求以及为什么app.use(express.static(__dirname))可能会提供不需要的文件。 ... [详细]
author-avatar
心动心爱_342
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有