pytorchでモデルを保存する方法は僕の知っているところ2通りあって、
torch.save(model, filename)
と
torhc.save(model.state_dict(), filename)
である。
後者はモデルというより、モデルのパラメータを保存してるらしい。
モデルを保存するサンプルコードを探すと、公式ドキュメントでもだいたいのサイトでも後者のmodel.state_dict()
を使うサンプルコードになっていると思う。
今回は、なぜ前者でなく後者なのか知る機会があったので、メモする
例えば自作のモジュールを使ってモデルを学習したとする。
そのモデルがtorch.save(model, filename)
で保存されたとすると、そのモデルをロードするときにその自作のモジュールも必要となる(使う使わないに限らず)。
今回、学習したモデルをロードして推論に使おうとしたが、モデル内のスクリプトでインポートされていたモジュールがインストールできなくて、モデルをロードできないということがあった。
正確には、モデルのモジュールがロードする時に、saveした時と同じモジュールがimport可能になっていないとダメって感じっぽい。
https://github.com/pytorch/pytorch/issues/3678
調べた感じだと、他にもGPUに送ったモデルをtorch.save(model, filename)
した場合、そのモデルをロードするときにGPUに置かれるから、GPUが使えないとモデルをロードできなくなるらしい。
ということがあるので、
torhc.save(model.state_dict(), filename)
を使います。