会话构建Session

Session用于操作和控制模型的上层系列接口

创建会话实例

Session *create_session(Graph *graph, int h, int w, int c, int truth_num, char *type, char *path);
参数 描述
Graph *graph 计算图
int h 输入图像数据的height
int w 输入图像数据的width
int c 输入图像数据的channel
int truth_num 离散标签数据个数
char *type 运行内核选择CPU/GPU
char *path 权重文件路径

初始化会话

void init_session(Session *sess, char *data_path, char *label_path);
参数 描述
Session *sess 会话实例
char *data_path 数据路径
char *label_path 标签路径

设置训练超参数

void set_train_params(Session *sess, int epoch, int batch, int subdivision, float learning_rate);
参数 描述
Session *sess 会话实例
int epoch 训练轮次
int batch 随机梯度下降批次大小
int subdivision 批次分割大小
float learning_rate 步长(学习率)

设置测试超参数

void set_detect_params(Session **sess*)
参数 描述
Session *sess 会话实例

运行训练

void train(Session *sess);
参数 描述
Session *sess 会话实例

运行测试

void detect_classification(Session *sess);
参数 描述
Session *sess 会话实例

设置参数优化器

void SGDOptimizer_sess(Session *sess, float momentum, float dampening, float decay, int nesterov, int maximize);
参数 描述
Session *sess 会话实例
float momentum 动量参数[0,1]
float dampening 动量衰减率[0,1]
float decay L2范数惩罚参数[0,1]
int nesterov 是否开启nesterov动量法(bool)
int maximize 梯度最小(0)梯度最大(1)

设置数据缩放

void transform_resize_sess(Session *sess, int height, int width);
参数 描述
Session *sess 会话实例
int height 缩放height
int width 缩放width

设置数据归一化

void transform_normalize_sess(Session *sess, float *mean, float *std);
参数 描述
Session *sess 会话实例
float *mean 通道对应均值
float *std 通道对应方差

以图像通道为单位,将图像数据归一化为均值为mean方差为std的分布区间