En PyTorch, torch.save()
y torch.load()
son funciones utilizadas para guardar y cargar modelos, así como sus parámetros o cualquier objeto serializable de PyTorch. Estas funciones son esenciales para almacenar el progreso del entrenamiento o reutilizar modelos preentrenados.
torch.save()
: Guarda un objeto de PyTorch en un archivo utilizando pickle. Se utiliza comúnmente para guardar los pesos o el estado de un modelo.
torch.load()
: Carga el objeto guardado de un archivo. Se usa para restaurar el modelo o los parámetros previamente guardados.
torch.save()
Para guardar un modelo entrenado, lo más común es almacenar el diccionario de estado (state_dict
), que contiene los parámetros del modelo.
import torch
# Guardar el estado del modelo
torch.save(model.state_dict(), 'modelo_entrenado.pth')
torch.load()
Para cargar el modelo guardado, debes crear una instancia del modelo e inicializarlo con los parámetros almacenados.
import torch
# Cargar el estado guardado
model = MiModelo() # Se crea una nueva instancia del modelo
model.load_state_dict(torch.load('modelo_entrenado.pth'))
model.eval() # Se pone el modelo en modo de evaluación
Es posible guardar el modelo completo, pero es más recomendable guardar solo el diccionario de estado.
# Guardar el modelo completo
torch.save(model, 'modelo_completo.pth')
# Cargar el modelo completo
model = torch.load('modelo_completo.pth')
model.eval()
Al igual que con los modelos, también puedes guardar y cargar los optimizadores para continuar el entrenamiento desde donde lo dejaste.
# Guardar el estado del optimizador
torch.save(optimizer.state_dict(), 'optimizador.pth')
# Cargar el estado del optimizador
optimizer.load_state_dict(torch.load('optimizador.pth'))
Para más información sobre cómo guardar y cargar modelos en PyTorch, consulta la documentación oficial de PyTorch.
Jorge García
Fullstack developer