tensorflowjsによる手書き文字認識を用いたゲーム

作成: 2020年04月10日

更新: 2020年12月03日

文字認識ゲーム

「令和」を書き初めして〇×で判定するゲームを作りました。
令和で書き初め
screen.png

手書き文字の学習

令、和それぞれの手書き文字教師データはなかったため自分で令(2種類)、和、令でも和でもないものをそれぞれ50枚ずつ手書きし教師データとした。PythonのKerasを用いて教師データからhdf5形式で学習済みモデルを出力した。画像データの学習方法は以下の記事を参考にしました。
Kerasで画像分類~前処理から分類テストまで~
教師データおよび学習プログラムは以下のリポジトリに置きました。
reiwa-learning

tensorflowjsによる判別

tensorflowjsは機械学習でよく用いられる本家tensorflowのJavascriptによるwrapperです。ブラウザ上でモデルの学習、分類、回帰などが行えます。ブラウザ上なので重めの処理は難しいですが、学習済みモデルを用いた簡単な画像分類程度なら問題なく処理できます。
CDNなら以下

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>

npmなら以下

npm install @tensorflow/tfjs

Kerasで作成したhdf5形式のモデルを用いるにはpythonを用いてtensorflowjsが処理できる形式に変換する必要があります。変換方法は以下の公式ドキュメント参照
Importing a Keras model into TensorFlow.js

ゲームの処理の流れ

ゲームの流れを簡単に説明すると以下のようになります。

  1. canvasにマウスクリック、マウス移動などのEventListenerを追加し、canvas上に文字を書けるようにする。
  2. canvasを画像データに変換し、得られた画像データを作成済みのモデルに適用し、令、和、その他のカテゴリそれぞれの適合率を算出する。その適合率によって正解、不正解を決める。
  3. 正解なら〇、不正解なら×をcanvas上に描いて一定時間後にcanvasを真っ白に戻す。

canvas上に文字を書く

canvas上に文字を書くにはcanvasにマウスやタッチに関するEventListenerを追加する必要があります。
以下のようにcanvas要素にEventListnerを追加します。

can = document.getElementById("can"); //canvas要素
can.addEventListener("touchstart",onDown, false); //画面に指が触れた
can.addEventListener("touchmove",onMove, false); //タッチ中の移動
can.addEventListener("touchend",onUp, false); //画面から指が離れた
can.addEventListener("mousedown",onMouseDown, false); //マウスの左ボタンを押した
can.addEventListener("mousemove",onMouseMove, false); //マウスの左ボタンを押しながらマウスを動かした
can.addEventListener("mouseup",onMouseUp, false); //マウスの左ボタンを話した
can.addEventListener("mouseleave",onMouseUp, false); //マウスカーソルがcanvasの外に出た
ct = can.getContext("2d");
ct.strokeStyle = "#000000"; //線を黒色に
ct.lineWidth = 7 * can.width / 128; //線の太さ

それぞれのイベント発生時に関数を実行させ、文字の描画を可能にしています。例えば以下のようにすればタッチで文字を書くことができます。

var mf = false; //タッチ中かどうか
var ox = 0,
    oy = 0,
    x = 0,
    y = 0;

function onDown(event) {
    mf = true;
    ox = event.touches[0].pageX -event.target.getBoundingClientRect().left; //タッチ場所のx座標取得
    oy = event.touches[0].pageY -event.target.getBoundingClientRect().top; //タッチ場所のy座標取得
}

function onUp(event) {
    mf = false;
    event.stopPropagation();
}

function onMove(event) {
    if (mf) {
        x = event.touches[0].pageX - event.target.getBoundingClientRect().left;
        y = event.touches[0].pageY - event.target.getBoundingClientRect().top;
        drawLine();
        ox = x;
        oy = y;
        event.preventDefault();
        event.stopPropagation();
    }
}

//(ox, oy) ~ (x, y) の間に線を引く
function drawLine() {
    ct.beginPath();
    ct.moveTo(ox, oy);
    ct.lineTo(x, y);
    ct.stroke();
}

マウスクリックの方も同様に実装することができます。

canvasを画像データに変換し学習済みモデルに適用

canvasに描かれた文字を画像に変換することでKerasによって学習されたモデルを適用し令、和、その他、それぞれのカテゴリを適用することができます。
Kerasの学習モデルファイルをロードするloadLayersModel関数はawaitなどを使用した非同期処理下でないといけないので注意が必要です。

var temp = document.createElement('canvas');
temp.width = 32;
temp.height = 64;
var tempCtx = temp.getContext('2d');
tempCtx.drawImage(can, 0, 0, temp.width, temp.height); // canvasをコピー
var imageData = tempCtx.getImageData(0, 0, temp.width, temp.height); // canvas画像データに変換
for (var i = 0; i < imageData.data.length / 4; i++) { // 念のためモノクロに変換
    var r = imageData.data[i * 4];
    var g = imageData.data[i * 4 + 1];
    var b = imageData.data[i * 4 + 2];
    imageData.data[i * 4] = (r + g + b) / 3;
    imageData.data[i * 4 + 1] = (r + g + b) / 3;
    imageData.data[i * 4 + 2] = (r + g + b) / 3;
}
var inputTensor = tf.browser.fromPixels(imageData, 3).toFloat();
var inputNormTensor = inputTensor.div(tf.scalar(255));
var [reiTensor, waTensor] = tf.split(inputNormTensor, 2); // 令の部分(上半分)と和の部分(下半分)に分割
reiTensor = reiTensor.reshape([1, 32, 32, 3]); //サイズを32*32に変換。3はRGB要素
waTensor = waTensor.reshape([1, 32, 32, 3]);
loadPretrainedModel();
async function loadPretrainedModel() { //モデルのロードは非同期処理である必要があるため
    const model = await tf.loadLayersModel(location.origin + "/other/reiwa/model/model.json");
    const reiPrediction = model.predict(reiTensor, { //予測実行
        batchSize: 1
    });
    const waPrediction = model.predict(waTensor, {
        batchSize: 1
    });
    const rei1 = reiPrediction.arraySync()[0][0];
    const rei2 = reiPrediction.arraySync()[0][1];
    const wa = waPrediction.arraySync()[0][2];
    if ((rei1 > 0.9 || rei2 > 0.9) && wa > 0.9) { //適合率0.9を正解不正解の境界に設定
        correctImgChange();
        correctSE.play();
        correctCount += 1;
        time += 3;
        document.getElementById("correctCount").innerHTML = correctCount;
        strokeCircle(0);
        strokeCircle(1);
    } else {
        wrongImgChange();
        wrongSE.play();
        wrongCount += 1
        document.getElementById("wrongCount").innerHTML = wrongCount;
        if ((rei1 > 0.9 || rei2 > 0.9)) {
            strokeCircle(0);
        } else {
            strokeCross(0);
        }
        if (wa > 0.9) {
            strokeCircle(1);
        } else {
            strokeCross(1);
        }
    }
    setTimeout(function () {
        sub = false;
        clearCan();
    }, 2000);
}

strokeCircle, strokeCrossはそれぞれ正解の〇、不正解の×を描画するための関数で以下のように定義されています。
丸はcanvasのarc関数として実装されているので簡単ですが、バツをcanvasで描画するにはバツ印の各頂点を指定する必要があり、大変でした。

function strokeCircle(num) { //num:何文字目か
    var p;
    if (num == 0) {
        p = 0;
    } else {
        p = canW;
    }
    ct.beginPath();
    ct.arc(canW / 2, canW / 2 + p, canW * 3 / 8, 0, Math.PI * 2, false);
    ct.strokeStyle = 'red';
    ct.stroke();
    ct.strokeStyle = 'black';
}

function strokeCross(num) { //num:何文字目か
    var p;
    if (num == 0) {
        p = 0;
    } else {
        p = canW;
    }
    var x = canW / 8;
    ct.beginPath();
    ct.moveTo(x * 2, x + p);
    ct.lineTo(x, x * 2 + p);
    ct.lineTo(x * 3, x * 4 + p);
    ct.lineTo(x, x * 6 + p);
    ct.lineTo(x * 2, x * 7 + p);
    ct.lineTo(x * 4, x * 5 + p);
    ct.lineTo(x * 6, x * 7 + p);
    ct.lineTo(x * 7, x * 6 + p);
    ct.lineTo(x * 5, x * 4 + p);
    ct.lineTo(x * 7, x * 2 + p);
    ct.lineTo(x * 6, x + p);
    ct.lineTo(x * 4, x * 3 + p);
    ct.fillStyle = 'blue';
    ct.fill();
}

まとめ

tensorflowjsを用いれば簡単な機械学習処理がクライアント側でできるのでサーバー側の言語に依存せずWebで機械学習ができます。機械学習を試す程度ならtensorflowjsを使うのもいいかもしれません。