XOR 异或逻辑回归

image

Google Playground网站

image

加载XOR数据集

1
2
3
//调脚本接口生成模拟数据
const data = getData(400);
console.log(data);
1
2
3
4
5
6
7
8
9
10
//可视化
tfvis.render.scatterplot(
{ name: "XOR训练数据" },
{
values: [
data.filter(p => p.label === 1),
data.filter(p => p.label === 0)
]
}
);

image

定义模型结构:多层神经网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
//初始化一个 sequential model
const mdoel = tf.sequential();

//添加一个隐藏层(全连接层)
mdoel.add(
tf.layers.dense({
units: 4,
inputShape: [2], //只有第一层需要设置inputShape
activition: "relu"
})
);

//添加一个输出层(全连接层)
model.add(
tf.layers.dense({
units: 1,
activition: "sigmoid" //需要输出[0,1]之间的概率所以选sigmoid
})
);
1
2
3
4
5
//定义模型的损失函数和优化器
model.compile({
loss: tf.losses.logLoss,
optimizer: tf.train.adam(0.1)
});

训练模型并预测

1
2
3
//训练数据转换为tensor
const inputs = tf.tensor(data.map(p => [p.x, p.y]));
const labels = tf.tensor(data.map(p => p.label));
1
2
3
4
5
//训练
await model.fit(inputs, labels, {
epochs: 10,
callbacks: tfvis.show.fitCallbacks({ name: "XOR训练过程" }, ["loss"])
});
1
2
3
4
5
6
7
//预测
window.predict = form => {
const pred = model.predict(
tf.tensor([[form.x.value * 1, form.y.value * 1]])
);
alert(`预测结果:${pred.dataSync()[0]}`);
};

image


代码仓库