Saving and loading tensors is a core part of working with PyTorch: it lets you persist data and models, share them, and resume computations later. PyTorch provides two functions for this, torch.save and torch.load.
Saving Tensors
You can save tensors to disk using the torch.save function. It serializes the object with Python's pickle and writes it to a binary file that torch.load can read back later. The .pt extension used here is a common convention.
Save a Single Tensor:
import torch
tensor = torch.tensor([1, 2, 3, 4, 5])
torch.save(tensor, 'tensor.pt')
Save Multiple Tensors:
tensors = {'tensor_a': torch.tensor([1, 2, 3]), 'tensor_b': torch.tensor([4, 5, 6])}
torch.save(tensors, 'tensors.pt')
Loading Tensors
To load tensors from disk, use the torch.load function. It reads the binary file back into memory and returns the object that was passed to torch.save, whether that was a single tensor or a dictionary of tensors.
Load a Single Tensor:
loaded_tensor = torch.load('tensor.pt')
print(loaded_tensor)
Load Multiple Tensors:
loaded_tensors = torch.load('tensors.pt')
print(loaded_tensors['tensor_a'])
print(loaded_tensors['tensor_b'])
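To check that a save and load round trip preserves the data, the loaded tensors can be compared with the originals. The following is a minimal sketch, not part of the original example: the temporary directory and the variable names are illustrative choices, and map_location="cpu" is an optional torch.load argument that places the loaded tensors on the CPU.

```python
import os
import tempfile

import torch

# Write to a temporary directory so the example leaves no files behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tensors.pt")

    original = {
        "tensor_a": torch.tensor([1, 2, 3]),        # integer tensor
        "tensor_b": torch.tensor([4.0, 5.0, 6.0]),  # floating-point tensor
    }
    torch.save(original, path)

    # map_location="cpu" loads the tensors onto the CPU regardless of the
    # device they were saved from.
    restored = torch.load(path, map_location="cpu")

# Both the values and the dtypes should survive the round trip.
values_match = all(torch.equal(original[k], restored[k]) for k in original)
dtypes_match = all(original[k].dtype == restored[k].dtype for k in original)
print(values_match, dtypes_match)
```

torch.equal compares shape and values exactly, which suits a round trip where no arithmetic happens between saving and loading.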
Example: Saving and Loading Tensors in a Neural Network
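Model parameters are tensors too. A module's state_dict() returns a dictionary that maps each parameter and buffer name to its tensor, so torch.save and torch.load can persist a model's weights and restore them later, for example to resume training.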
Save Model State:
import torch.nn as nn
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNN()
torch.save(model.state_dict(), 'model_state.pt')
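It can help to see what the saved state actually contains. The sketch below repeats the SimpleNN definition so that it runs on its own and lists each entry of the state dictionary with its shape. The shapes variable is only an illustrative name.

```python
import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


model = SimpleNN()

# state_dict() maps each parameter name to its tensor.
shapes = {name: tuple(t.shape) for name, t in model.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
# fc1.weight (50, 10)
# fc1.bias (50,)
# fc2.weight (1, 50)
# fc2.bias (1,)
```

Only the tensors are stored, not the class definition, which is why loading requires creating a SimpleNN instance first.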
Load Model State:
model = SimpleNN()
model.load_state_dict(torch.load('model_state.pt'))
model.eval() # Set the model to evaluation mode
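One way to confirm that the state was restored correctly is to check that the reloaded model produces the same outputs as the original. This is a minimal sketch under a few assumptions: the names trained and fresh are illustrative, the "trained" model is just a randomly initialized stand-in, and the file goes to a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


trained = SimpleNN()  # stands in for a model you have trained
fresh = SimpleNN()    # new instance with different random weights

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_state.pt")
    torch.save(trained.state_dict(), path)
    fresh.load_state_dict(torch.load(path))

trained.eval()
fresh.eval()

# After loading, both models should map the same input to the same output.
x = torch.randn(4, 10)
with torch.no_grad():
    outputs_match = torch.allclose(trained(x), fresh(x))
print(outputs_match)
```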
Saving and Loading Checkpoints
When training large models, it is often useful to save checkpoints periodically so that you can resume training from the last checkpoint in case of interruptions.
Save Checkpoint:
# `optimizer` and `loss` are assumed to come from your training loop.
checkpoint = {
    'epoch': 10,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pt')
Load Checkpoint:
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
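The save and load steps above can be put together into a full interrupt-and-resume cycle. The sketch below is illustrative only: the small nn.Linear model, the SGD settings, the synthetic data, and the train_one_epoch helper are all assumptions made for the example, not part of the original code. It stores the loss as a plain Python float via loss.item() to keep the checkpoint simple.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model_and_optimizer():
    """Build a small stand-in model and its optimizer (hypothetical setup)."""
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    return model, optimizer


def train_one_epoch(model, optimizer, inputs, targets):
    """Run one full-batch update and return the loss as a float."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


# Synthetic data, for illustration only.
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pt")

    # First run: train for three epochs, then save a checkpoint.
    model, optimizer = make_model_and_optimizer()
    for epoch in range(3):
        loss = train_one_epoch(model, optimizer, inputs, targets)
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "loss": loss,
        },
        path,
    )
    saved_weight = model.weight.detach().clone()

    # Simulated restart: rebuild the objects, then restore their state.
    model, optimizer = make_model_and_optimizer()
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    start_epoch = checkpoint["epoch"] + 1  # continue after the saved epoch

weights_restored = torch.equal(model.weight, saved_weight)

# Resume training from where the first run stopped.
model.train()
for epoch in range(start_epoch, 5):
    loss = train_one_epoch(model, optimizer, inputs, targets)

print(start_epoch, weights_restored)
```

Restoring the optimizer state matters for optimizers that keep running statistics, such as the momentum buffers used here. Without it, the resumed run would restart those statistics from zero.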