迁移学习

加载商标训练数据并可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// data.js
const loadImg =(src)=>{
return new Promise(resolve=>{
const img = document.createElement('img')
img.crossOrigin = 'anonymous'
img.src = src
img.width = 224 //以mobileNet为截断模型,其接收图片尺寸为224
img.height = 224
img.onload=()=>reslove(img)
})
}

// 返回Promise
export const getInputs = async()=>{
const loadImgs = []
const labels = []
for(let i=0;i<30;i+=1){
['android','apple','windows'].forEach(label=>{
const imgP = loadImg(`http://127.0.0.1:8080/brand/train/${label}-${index}.jpg`)
loadImgs.push(imgP)
labels.push([
label === 'android' ? 1 :0,
label === 'apple' ? 1 :0,
label === 'windows' ? 1 :0,
])
})
}
const inputs = await Promise.all(loadImgs)
return{
inputs, labels
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// script.js
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x, file2img } from './utils';

const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];

window.onload = async () => {
const { inputs, labels } = await getInputs();
// console.log([inputs,labels])

//将加载的图片素材可视化
const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
inputs.forEach(img => {
surface.drawArea.appendChild(img);
});
};

加载预训练好的模型Mobilenet

1
2
3
4
5
//加载预训练好的模型Mobilenet
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);

//mobilenet的方法,给出其神经网络的概览
mobilenet.summary();

mobilenet模型概览

定义截断模型

1
2
3
4
5
6
7
8
//获取中间层
const layer = mobilenet.getLayer('conv_pw_13_relu');

//定义一个截断模型truncatedMobilenet
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output
});

定义双层的迁移模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

//定义一个模型
const model = tf.sequential();

//添加一个flatten层(将截断模型提取的高维特征提取成一维向量,这一层没有参数,起转换作用
model.add(tf.layers.flatten({
inputShape: layer.outputShape.slice(1)
}));

//添加一个全链接层:用于训练我们的商标图片
model.add(tf.layers.dense({
units: 10,
activation: 'relu'
}));

//添加一个全链接层:用于做多分类
model.add(tf.layers.dense({
units: NUM_CLASSES,
activation: 'softmax'
}));

//设置损失函数:分类交叉熵损失函数,优化器为adam
model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });

先用截断模型训练数据,转为可以用于迁移模型的数据

1
2
3
4
5
6
//训练数据 先经过截断模型,转为可以用于迁移模型的数据
const { xs, ys } = tf.tidy(() => {
const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
const ys = tf.tensor(labels);
return { xs, ys };
});

训练迁移模型

1
2
3
4
5
6
7
8
9
//训练迁移模型
await model.fit(xs, ys, {
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: '训练效果' },
['loss'],
{ callbacks: ['onEpochEnd'] }
)
});

迁移模型训练效率高

预测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
const input = truncatedMobilenet.predict(x);
return model.predict(input);
});

const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`预测结果:${BRAND_CLASSES[index]}`);
}, 0);
};

预测效果

模型的保存和加载

1
2
3
window.download = async () => {
await model.save('downloads://model');
};

代码仓库