Introduction
保存模型
- 模型的保存就是將訓練後改變的變量儲存起來
- 保存模型的文件格式為checkpoint文件(檢查點文件)
train.Saver
使用tf.train.Saver(var_list=None,max_to_keep=5)返回一個saver物件
- var_list:指定要保存和還原的變量,可以作為一個dict或是一個list傳遞進去
- max_to_keep:指示要保留的最近檢查點文件的最大數量- 創建新文件時,會刪除較舊的文件
- 如果無舊文件,則保留所有檢查點文件,預設為5(即保留最新的5個檢查點文件)
 
- 使用返回的saver物件去操作保存(save)或是加載(restore)- saver.save(sess,"保存路徑及模型名")
- saver.restore(sess,"加載路徑及模型名")
 
範例
| 1 | import tensorflow as tf | 
Result
在欲儲存的路徑下會多了四個文件分別如下
- checkpoint文件中,只記載model儲存的位置及檔名
- 變量數據皆保存在 後綴名為.data-00000-of-00001的文件中- 每次更新時會一直往後加1- data-00001-of-00002
- data-00002-of-00003
 
 
- 每次更新時會一直往後加1
加載模型
- 再次訓練之前應先加載模型
- 加載後會覆蓋graph中使用變量定義的參數,例如:權重(weight),偏置(bias)
- 與儲存模型一樣須創建一個 加載/儲存的 saver物件,調用restore()方法加載
- saver.restore(sess,"加載路徑及模型名")
範例
| 1 | import tensorflow as tf | 
Result
| 1 | 起始初始化權重:0.093984, 初始化偏置:0.000000 | 
- 可以看到加載後的模型就是上次訓練最後的結果,其又被繼續練,並儲存
 
              