Volver a la página principal
miércoles 25 septiembre 2024
59

Cómo guardar y cargar modelos en PyTorch con save() y load()

En PyTorch, torch.save() y torch.load() son funciones utilizadas para guardar y cargar modelos, así como sus parámetros o cualquier objeto serializable de PyTorch. Estas funciones son esenciales para almacenar el progreso del entrenamiento o reutilizar modelos preentrenados.

¿Cómo funciona torch.save() y torch.load() en Python?

  • torch.save(): Guarda un objeto de PyTorch en un archivo utilizando pickle. Se utiliza comúnmente para guardar los pesos o el estado de un modelo.
  • torch.load(): Carga el objeto guardado de un archivo. Se usa para restaurar el modelo o los parámetros previamente guardados.

Guardar un modelo con torch.save()

Para guardar un modelo entrenado, lo más común es almacenar el diccionario de estado (state_dict), que contiene los parámetros del modelo.

import torch

# Guardar el estado del modelo
torch.save(model.state_dict(), 'modelo_entrenado.pth')

Cargar un modelo con torch.load()

Para cargar el modelo guardado, debes crear una instancia del modelo e inicializarlo con los parámetros almacenados.

import torch

# Cargar el estado guardado
model = MiModelo()  # Se crea una nueva instancia del modelo
model.load_state_dict(torch.load('modelo_entrenado.pth'))
model.eval()  # Se pone el modelo en modo de evaluación

Ejemplos adicionales

Guardar y cargar un modelo completo

Es posible guardar el modelo completo, pero es más recomendable guardar solo el diccionario de estado.

# Guardar el modelo completo
torch.save(model, 'modelo_completo.pth')

# Cargar el modelo completo
model = torch.load('modelo_completo.pth')
model.eval()

Guardar y cargar optimizadores

Al igual que con los modelos, también puedes guardar y cargar los optimizadores para continuar el entrenamiento desde donde lo dejaste.

# Guardar el estado del optimizador
torch.save(optimizer.state_dict(), 'optimizador.pth')

# Cargar el estado del optimizador
optimizer.load_state_dict(torch.load('optimizador.pth'))

Referencia oficial

Para más información sobre cómo guardar y cargar modelos en PyTorch, consulta la documentación oficial de PyTorch.

Compartir:
Creado por:
Author photo

Jorge García

Fullstack developer