XOR

异或(XOR)问题

异或函数XOR,是两个二进制数a,b的运算,当且仅当其中一个值为1时,XOR结果为1,其余结果为0

异或

标签 数据
1 [1, 0]
1 [0, 1]
0 [1, 1]
0 [0, 0]

异或问题是典型的非线性问题

逻辑与

标签 数据
1 [1, 1]
0 [1, 0]
0 [0, 1]
0 [0, 0]

逻辑或

标签 数据
1 [1, 0]
1 [0, 1]
1 [1, 1]
0 [0, 0]

异或,逻辑与,逻辑或的散点图如下

可以看出,逻辑与和逻辑或的数据分布可以用一个线性函数进行分割,而异或无法用单一线性函数进行划分,所以XOR具有典型非线性

数据集

Lumos框架已提供xor数据集,在Lumos项目demo目录下

代码构建

我们构建一个简单的全连接神经网络来解决XOR问题,其网络结构如下

使用Lumos框架构建网络模型

Graph *g = create_graph();
Layer *l1 = make_connect_layer(8, 1, "relu");
Layer *l2 = make_connect_layer(16, 1, "relu");
Layer *l3 = make_connect_layer(2, 1, "linear");
Layer *l4 = make_crossentropy_layer(NULL, -1);

我们使用crossentropy分类器进行分类

接下来构建会话,并设置相关训练超参数

Session *sess = create_session(g, 1, 2, 1, 2, type, path);
set_train_params(sess, 150, 4, 4, 0.1);
SGDOptimizer_sess(sess, 0.9, 0, 0, 0, 0);
init_session(sess, "./demo/xor/data.txt", "./demo/xor/label.txt");

我们使用SGD参数优化器进行参数优化

完整代码如下

#include "xor.h"

void xor(char *type, char *path)
{
    Graph *g = create_graph();
    Layer *l1 = make_connect_layer(8, 1, "relu");
    Layer *l2 = make_connect_layer(16, 1, "relu");
    Layer *l3 = make_connect_layer(2, 1, "linear");
    Layer *l4 = make_crossentropy_layer(NULL, -1);
    append_layer2grpah(g, l1);
    append_layer2grpah(g, l2);
    append_layer2grpah(g, l3);
    append_layer2grpah(g, l4);
    Session *sess = create_session(g, 1, 2, 1, 2, type, path);
    set_train_params(sess, 150, 4, 4, 0.1);
    SGDOptimizer_sess(sess, 0.9, 0, 0, 0, 0);
    init_session(sess, "./demo/xor/data.txt", "./demo/xor/label.txt");
    train(sess);
}

void xor_detect(char *type, char *path)
{
    Graph *g = create_graph();
    Layer *l1 = make_connect_layer(8, 1, "relu");
    Layer *l2 = make_connect_layer(16, 1, "relu");
    Layer *l3 = make_connect_layer(2, 1, "linear");
    Layer *l4 = make_crossentropy_layer(NULL, -1);
    append_layer2grpah(g, l1);
    append_layer2grpah(g, l2);
    append_layer2grpah(g, l3);
    append_layer2grpah(g, l4);
    Session *sess = create_session(g, 1, 2, 1, 2, type, path);
    set_detect_params(sess);
    init_session(sess, "./demo/xor/data.txt", "./demo/xor/label.txt");
    detect_classification(sess);
}

在Lumos框架中demo目录下,您能找到xor.c文件,这就是我们已实现的XOR模型

结果展示

该网络在经过150个epoch训练后,可以准确的对XOR数据进行分类,分类精度100%