С помощью
фреймворка PyTorch можно создавать
довольно крупные НС, которые, как правило, обучаются длительное время.
Соответственно, встает вопрос сохранения обученных моделей на диск и их
последующая загрузка. Давайте посмотрим, как это делается на примере ранее
созданной модели распознавания изображений цифр:
neuro_net_21.py: https://github.com/selfedu-rus/neuro-pytorch
Если
непосредственно после создания модели:
model = DigitNN(28 * 28, 32, 10)
вызвать метод
state_dict:
То переменная st будет ссылаться
на словарь состояний модели model. В данном случае он будет иметь вид:
st = {OrderedDict: 4} OrderedDict(
'layer1.weight' = {Tensor: (32, 784)},
'layer1.bias'
= {Tensor: (32,)},
'layer2.weight' = {Tensor: (10, 32)},
'layer2.bias' = {Tensor: (10,)})
То есть,
содержать веса связей и смещений (biases) текущей
модели. Затем, все эти данные (словарь) можно сохранить на диске с помощью
функции save фреймворка PyTorch:
torch.save(st, 'model_dnn.tar')
В результате в
текущем каталоге появится архивный файл 'model_dnn.tar' с сохраненными
весами связей. Эти данные могут быть загружены в любую НС с такой же структурой,
и она станет распознавать изображения цифр, как это делала сохраненная модель.
Для
этого вначале вызывается функция load:
state_model = torch.load('model_dnn.tar', weights_only=True)
Параметр weights_only=True означает, что
выполняется загрузка примитивных типов данных, например: словарей, тензоров,
списков, строк и т.п. В данном случае мы используем словарь.
После загрузки словарь
состояний state_model передается в
любую подходящую модель (той же архитектуры):
model.load_state_dict(state_model)
Вот общий
принцип записи данных в файл и их последующей загрузки. Причем сохранять можно
вообще произвольные данные. Например, создадим тензор:
tf = torch.tensor([1, 2, 3])
и сохраним его:
Выполним
загрузку данных из этого же файла:
t = torch.load("ex.tar", weights_only=True)
и видим, что t – это такой же
тензор, что и tf. То есть,
функция load() сразу
возвращает объект ранее сохраненного типа. Как пример, создадим строку:
и также сохраним
ее в файл:
После загрузки:
t = torch.load("ex.tar", weights_only=True)
видим, что t – это строка с
соответствующим содержимым. И так можно делать со всеми стандартными типами
данных.
Однако при
работе с тензорами и моделями НС есть один важный нюанс. Если тензор изначально
расположен на графическом процессоре:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tf = torch.tensor([1, 2, 3], device=device)
То при его
сохранении и последующей загрузке:
torch.save(tf, "ex.tar")
t = torch.load("ex.tar", weights_only=True)
Тензор t также окажется
на GPU. Если нам нужно
другое или вполне определенное его расположение, то дополнительно это можно
указать через параметр map_location, например, так:
t = torch.load("ex.tar", weights_only=True, map_location="cpu")
Теперь тензор t будет
расположен на CPU. Иногда такое
явное указание типа устройства бывает очень полезным.
Сохранение состояния модели в процессе обучения
Итак, как в
целом выполнять сохранение и загрузку тензоров, словарей и других встроенных
типов данных, мы с вами разобрались. Все это применяется не только для
сохранения полностью обученных моделей, но и для формирования контрольных точек
в процессе самого обучения, когда обучение затягивается на дни или даже месяцы.
Как вы понимаете, при длительной работе программ могут происходить самые разные
непредвиденные ситуации, из-за которых процесс обучения может прерваться. И
начинать все заново было бы очень неприятно. Поэтому на практике формируют
промежуточные контрольные точки (check point) с текущими
весами модели. Если происходит непредвиденная остановка, то достаточно
загрузить сохраненную модель и продолжить процесс обучения. Давайте рассмотрим
пример того, как это можно было бы сделать.
Перед главным
циклом обучения введем вспомогательную переменную:
best_loss = 1e10 # заведомо высокое значение
А в самом цикле
на каждой эпохе будем сравнивать превышение значения best_loss над значением
средних потерь loss_mean на 10% для
определения значимых улучшений:
for _e in range(epochs):
loss_mean = 0
lm_count = 0
train_tqdm = tqdm(train_data, leave=True)
for x_train, y_train in train_tqdm:
...
if best_loss > loss_mean * 1.1:
best_loss = loss_mean
st = model.state_dict()
torch.save(st, f'model_dnn_{_e}.tar')
Выражение best_loss > loss_mean * 1.1
математически можно получить следующим образом:
Если теперь на
текущей эпохе предыдущие потери окажутся на 10% выше, чем текущие средние, то
формируется контрольная точка (сохраняются веса модели). Соответственно, файл с
самым большим номером будет содержать лучшие значения весов связей.
Конечно,
величина 10% взята с позиции здравого смысла. Это не какое-то строгое значение.
Вполне можно выбрать 15%, 20% или даже 5% в зависимости от протекания процесса
обучения и решаемой задачи.
Сохранение и загрузка дополнительных объектов
Помимо модели НС
сохранять в файле можно и некоторые другие объекты фреймворка PyTorch. Как это
делается с тензорами, вы уже видели. Но помимо них аналогичные операции
допустимо выполнять с объектом трансформации, оптимизатором. Например,
сформируем следующий общий словарь состояния модели:
model_state_dict = {
'tfs': transforms.state_dict(),
'opt': optimizer.state_dict(),
'model': model.state_dict(),
}
И будем
сохранять его в главном цикле обучения:
for _e in range(epochs):
loss_mean = 0
lm_count = 0
train_tqdm = tqdm(train_data, leave=True)
for x_train, y_train in train_tqdm:
...
if best_loss > loss_mean * 1.1:
best_loss = loss_mean
model_state_dict['model'] = model.state_dict()
torch.save(model_state_dict, f'model_dnn_{_e}.tar')
После обучения
(например, в другом модуле) загрузим эти данные:
model_data = torch.load('model_dnn_1.tar', weights_only=True)
model.load_state_dict(model_data['model'])
transforms.load_state_dict(model_data['tfs'])
optimizer.load_state_dict(model_data['opt'])
В результате мы
сохранили, а затем, загрузили данные по модели, преобразованиям и оптимизатору.
Как видите, все достаточно просто и удобно.