多分类问题
- 手写数字分类:10种分类
- 图片分类:成千上万种分类
- 多种特征:需要多层神经网络
- 多种分类输出:需要在输出层加softMax函数
鸢尾花(iris)分类问题
- 著名的数据集,诞生了很久,被无数科学家用来验证自己的算法
- 三种分类(输出):山鸢尾,变色鸢尾,Virginica鸢尾
- 四种特征(输入):花瓣长、宽、花萼长、宽
加载IRIS数据集(训练集与验证集)
- 用脚本生成IRIS数据集:训练集+验证集
1 | //生成:训练集特征,训练集标签,验证集特征,验证集标签(数据类型为tensor) |
- 打印结果:
定义模型结构:带有softMax激活函数的多层神经网络
- 初始化一个神经网络模型
1 | //初始化模型 |
- 为模型添加两个层
- 设计层的神经元个数、inputShape、激活函数
1 | //添加隐藏层:全链接层 |
交叉熵损失函数
- 交叉熵损失函数 Cross-Entropy:是LogLoss对数损失函数的多分类版本,都用于度量分类神经网络模型的性能。
- 当分类数为2时,交叉熵损失=对数损失。
定义损失函数、优化器、准确度度量
1 | //设置损失函数,增加训练过程中的“准确度”度量 |
训练模型并可视化
1 | //训练并可视化 |
- 训练过程:
模型多分类预测
1 | window.predict = (form)=>{ |