< NCNN-Lession-8 > 读取网络的权重信息

开始

今天我们要学些ncnn怎么load权重信息,咱们再插一个小红旗:

作用

作用当然是给含有权重的层把权重(与层的参数区分开)给load到对应的内存中.

包含有权重的层最常用也就两种:

卷积层

bn层

我们今天就以卷积层为例子,来说一下ncnn中Load 权重的实现.

实现

实现Load 权重的功能其实很简单,因为我们在第一节就学习了数据读取类DataReader,我们读取网络proto的时候是就是用的这个类,在读取权重的时候,也是使用的这个类.我们重新把DataReader的框架图放到下面来再看一下:

我们在读取proto信息的时候,主要是用了scan这个成员函数,因为proto信息中结构化的信息比较多.

在我们读取网络权重的时候,主要需要用到read这个函数,因为权重是以二进制的方式存储的,我们在知道elemsize和权重个数的情况下,就可以方便的通过read这个成员函数读取.

ncnn中用ModelBinFromDataReader这个类来实现权重的读取,它其实是对于DataReader这个类的一种包装,它的框架图如下:

成员变量就不用多说,我们看一下它的成员函数的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
Mat ModelBinFromDataReader::load(int w, int type) const {                                               
if (type == 0) {
size_t nread;

union {
struct {
unsigned char f0;
unsigned char f1;
unsigned char f2;
unsigned char f3;
};
unsigned int tag;

} flag_struct;

nread = dr.read(&flag_struct, sizeof(flag_struct));

unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
Mat m(w);
if (m.empty()) return m;
if (flag != 0) {
return m;
} else if (flag_struct.f0 == 0) {
nread = dr.read(m, w * sizeof(float));
}
return m;
} else if (type == 1) {
Mat m(w);
if (m.empty()) return m;
size_t nread = dr.read(m, w * sizeof(float));
return m;
}
return Mat();
}

load这个函数,它有两个参数:

w 它表示需要读取的元素个数

type 是否去要按照不同的精度读取

这里,对type做一下说明,当type为0的时候,需要读取数据的头部4个字节,然后由头部四个字节的数据来决定按照什么精度读取.这里我对代码做了简化,按照默认的float精度(4字节)的去读.

当type为1的时候,则不读取头部的四个字节,直接按照默认的float精度(4字节)去读.

有了ModelBinFromDataReader这个类,我们就可以在需要load权重的类里面实现一个load_model的函数去调用ModelBinFromDataReader这个类的对象.

在不需要load权重的类里面就不用实现load_model的函数,只需要继承来自Layer父类的load_model函数就行.

最终在net这个类里面,也实现一个load_model的函数,在这个函数里,按顺序遍历每一个层的load_model成员函数即可.

代码示例

测试程序在这里

代码结构如下:

< NCNN-Lession-8 > 读取网络的权重信息

https://zhengtq.github.io/2020/12/18/ncnn-lesson-8/

Author

Billy

Posted on

2020-12-18

Updated on

2021-03-16

Licensed under

Comments