Introduction
保存模型
- 模型的保存就是將訓練後改變的變量儲存起來
- 保存模型的文件格式為checkpoint文件(檢查點文件)
train.Saver
使用tf.train.Saver(var_list=None,max_to_keep=5)
返回一個saver
物件
var_list
:指定要保存和還原的變量,可以作為一個dict或是一個list傳遞進去max_to_keep
:指示要保留的最近檢查點文件的最大數量- 創建新文件時,會刪除較舊的文件
- 如果無舊文件,則保留所有檢查點文件,預設為5(即保留最新的5個檢查點文件)
- 使用返回的
saver
物件去操作保存(save)或是加載(restore)saver.save(sess,"保存路徑及模型名")
saver.restore(sess,"加載路徑及模型名")
範例
1 | import tensorflow as tf |
Result
在欲儲存的路徑下會多了四個文件分別如下
- checkpoint文件中,只記載model儲存的位置及檔名
- 變量數據皆保存在 後綴名為.data-00000-of-00001的文件中
- 每次更新時會一直往後加1
- data-00001-of-00002
- data-00002-of-00003
- 每次更新時會一直往後加1
加載模型
- 再次訓練之前應先加載模型
- 加載後會覆蓋graph中使用變量定義的參數,例如:權重(weight),偏置(bias)
- 與儲存模型一樣須創建一個 加載/儲存的
saver
物件,調用restore()
方法加載 saver.restore(sess,"加載路徑及模型名")
範例
1 | import tensorflow as tf |
Result
1 | 起始初始化權重:0.093984, 初始化偏置:0.000000 |
- 可以看到加載後的模型就是上次訓練最後的結果,其又被繼續練,並儲存