多分类问题

鸢尾花(iris)分类问题

加载IRIS数据集(训练集与验证集)

1
2
3
4
5
6
7
//生成:训练集特征,训练集标签,验证集特征,验证集标签(数据类型为tensor)
const [xTrain,yTrain,xTest,yTest] = getIrisData(0.15);
xTrain.print();
yTrain.print();
xTest.print();
yTest.print();
console.log(IRIS_CLASSES);

image

定义模型结构:带有softMax激活函数的多层神经网络

1
2
//初始化模型
const model = tf.sequential();
1
2
3
4
5
6
7
8
9
10
11
12
//添加隐藏层:全链接层
model.add(tf.layers.dense({
units:10, //10个神经元,超参数
inputShape:[xTrain.shape[1]],
activation:'sigmoid',
}));
//添加输出层:全链接层
model.add(tf.layers.dense({
units:3, //必须是输出类别的个数
// inputShape:[yTrain.shape], //除了第一层以外 都不需要设计inputShape,会根据上一层的输出自动设计
activation:'softmax', //softmax 激活函数 适用于多种分类输出层
}));

交叉熵损失函数

image

定义损失函数、优化器、准确度度量

1
2
3
4
5
6
//设置损失函数,增加训练过程中的“准确度”度量
model.compile({
loss:'categoricalCrossentropy',
optimizer: tf.train.adam(0.1),
metrics: ['accuracy']
});

训练模型并可视化

1
2
3
4
5
6
7
8
9
10
//训练并可视化
await model.fit(xTrain,yTrain,{
epochs:100,
validationData:[xTest,yTest],
callbacks:tfvis.show.fitCallbacks(
{name:''},
['loss','val_loss','acc','val_acc'],
{callbacks:['onEpochEnd']},
),
})

image

模型多分类预测

1
2
3
4
5
6
7
8
9
10
11
12
13
window.predict = (form)=>{
const input = tf.tensor([[
form.a.value * 1,
form.b.value * 1,
form.c.value * 1,
form.d.value * 1,
]]);

debugger;
const pred = model.predict(input);
pred.print();
alert(`预测结果${IRIS_CLASSES[pred.argMax(1).dataSync(0)]}`);
}

image


代码仓库