
会话构建Session
Session用于操作和控制模型的上层系列接口
创建会话实例
Session *create_session(Graph *graph, int h, int w, int c, int truth_num, char *type, char *path);
| 参数 | 描述 |
|---|---|
| Graph *graph | 计算图 |
| int h | 输入图像数据的height |
| int w | 输入图像数据的width |
| int c | 输入图像数据的channel |
| int truth_num | 离散标签数据个数 |
| char *type | 运行内核选择CPU/GPU |
| char *path | 权重文件路径 |
初始化会话
void init_session(Session *sess, char *data_path, char *label_path);
| 参数 | 描述 |
|---|---|
| Session *sess | 会话实例 |
| char *data_path | 数据路径 |
| char *label_path | 标签路径 |
设置训练超参数
void set_train_params(Session *sess, int epoch, int batch, int subdivision, float learning_rate);
| 参数 | 描述 |
|---|---|
| Session *sess | 会话实例 |
| int epoch | 训练轮次 |
| int batch | 随机梯度下降批次大小 |
| int subdivision | 批次分割大小 |
| float learning_rate | 步长(学习率) |
设置测试超参数
void set_detect_params(Session **sess*)
| 参数 | 描述 |
|---|---|
| Session *sess | 会话实例 |
运行训练
void train(Session *sess);
| 参数 | 描述 |
|---|---|
| Session *sess | 会话实例 |
运行测试
void detect_classification(Session *sess);
| 参数 | 描述 |
|---|---|
| Session *sess | 会话实例 |
设置参数优化器
void SGDOptimizer_sess(Session *sess, float momentum, float dampening, float decay, int nesterov, int maximize);
| 参数 | 描述 |
|---|---|
| Session *sess | 会话实例 |
| float momentum | 动量参数[0,1] |
| float dampening | 动量衰减率[0,1] |
| float decay | L2范数惩罚参数[0,1] |
| int nesterov | 是否开启nesterov动量法(bool) |
| int maximize | 梯度最小(0)梯度最大(1) |
设置数据缩放
void transform_resize_sess(Session *sess, int height, int width);
| 参数 | 描述 |
|---|---|
| Session *sess | 会话实例 |
| int height | 缩放height |
| int width | 缩放width |
设置数据归一化
void transform_normalize_sess(Session *sess, float *mean, float *std);
| 参数 | 描述 |
|---|---|
| Session *sess | 会话实例 |
| float *mean | 通道对应均值 |
| float *std | 通道对应方差 |
以图像通道为单位,将图像数据归一化为均值为mean方差为std的分布区间