tensorflowで学習の中断・再開 | 15g.jp

今回は「学習の中断・再開」です。学習データの保存と読込を行います。
tensorflow自体の導入方法は前回をご覧ください。

前回に引き続きmnistを使います。

基本的な機能だけの紹介になりますので、
詳細は公式リファレンスをご覧下さい。

mnistのこの部分を変更します。

中断・再開機能を追加するとこうなります。

作成されるファイルは以下になります。
my-model-xxx:xxxにはステップ数が入ります。
checkpoint:作成したモデルのファイル名が保存されています。checkpoint自体の名前を変更することもできますが、今回はデフォルトで行きます。

では、コードを順に見て行きましょう。

Saverクラスを作成します。max_to_keepは中断ファイルの作成数です。デフォルトは5で、古いものから削除されます。
学習の推移を見たい場合は、この数値を増やすよりステップ数の間隔を広げたほうが良いと思います。(そもそもtensorboard使ったほうがいい気も済ますが)

モデルの有無で分岐させます。

最後に保存されたモデルファイルを呼び出します。

checkpointファイルがない場合。

以上です。
実際はmodel、checkpoint用のディレクトリを作成した方がよいでしょう。
また、checkpointは作成されているけどmodelファイルがない、等の場合はエラーになりますがその辺りは割愛しています。