Part 9: Saving and Loading Tensors

by digitaltech2.com

Saving and loading tensors is a crucial aspect of working with PyTorch: it lets you persist data and models, share them, and resume computations. PyTorch provides the torch.save and torch.load functions for this purpose.

Saving Tensors

You can save tensors to disk using the torch.save function. This function allows you to save tensors in a binary format that can be loaded later.

Save a Single Tensor:

import torch

tensor = torch.tensor([1, 2, 3, 4, 5])
torch.save(tensor, 'tensor.pt')

Save Multiple Tensors:

tensors = {'tensor_a': torch.tensor([1, 2, 3]), 'tensor_b': torch.tensor([4, 5, 6])}
torch.save(tensors, 'tensors.pt')
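
Besides a file path, torch.save also accepts any file-like object, which can be handy for in-memory serialization. A small sketch using io.BytesIO (the buffer name is just illustrative):

import io

buffer = io.BytesIO()
torch.save(tensors, buffer)   # serialize to an in-memory buffer instead of a file
buffer.seek(0)                # rewind before reading it back
restored = torch.load(buffer)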

Loading Tensors

To load tensors from disk, use the torch.load function. This function can load tensors saved in a binary format back into memory.

Load a Single Tensor:

loaded_tensor = torch.load('tensor.pt')
print(loaded_tensor)

Load Multiple Tensors:

loaded_tensors = torch.load('tensors.pt')
print(loaded_tensors['tensor_a'])
print(loaded_tensors['tensor_b'])
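
By default, tensors are loaded back onto the device they were saved from. If a file was written on a GPU machine and you want to read it on a CPU-only machine, torch.load accepts a map_location argument that remaps the storages. A minimal sketch, assuming a hypothetical file gpu_tensor.pt that was saved from a CUDA device:

# 'gpu_tensor.pt' is a hypothetical file saved from a CUDA device
cpu_tensor = torch.load('gpu_tensor.pt', map_location='cpu')

Recent PyTorch versions also accept a weights_only argument to torch.load, which restricts unpickling to plain tensor data and is safer when loading files from untrusted sources.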

Example: Saving and Loading Tensors in a Neural Network

Saving and loading tensors can be particularly useful in the context of neural networks for saving model states and resuming training.

Save Model State:

import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNN()
torch.save(model.state_dict(), 'model_state.pt')

Load Model State:

model = SimpleNN()
model.load_state_dict(torch.load('model_state.pt'))
model.eval()  # Set the model to evaluation mode
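
The state_dict approach stores only the parameter tensors, so the SimpleNN class must be defined before the weights can be loaded into it. It is also possible to save the entire model object in one call, although this pickles the class itself and tends to break if the source code changes. A minimal sketch, using a hypothetical file name model_full.pt:

# Save the whole model object (pickles the class as well as the weights)
torch.save(model, 'model_full.pt')

# Loading requires the SimpleNN class to be importable in the current script;
# depending on your PyTorch version you may also need weights_only=False here.
loaded_model = torch.load('model_full.pt')
loaded_model.eval()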

Saving and Loading Checkpoints

When training large models, it is often useful to save checkpoints periodically so that you can resume training from the last checkpoint in case of interruptions.

Save Checkpoint:

# Assumes `model`, `optimizer`, and `loss` already exist from your training loop
checkpoint = {
    'epoch': 10,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss
}
torch.save(checkpoint, 'checkpoint.pt')
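
In practice, the checkpoint is usually written from inside the training loop, for example every few epochs. A sketch of how that might look, where train_one_epoch and num_epochs are hypothetical placeholders for your own training code:

for epoch in range(num_epochs):
    loss = train_one_epoch(model, optimizer)  # hypothetical training step
    if (epoch + 1) % 5 == 0:  # save every 5 epochs
        torch.save({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'loss': loss,
        }, f'checkpoint_epoch_{epoch + 1}.pt')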

Load Checkpoint:

# Recreate the model and optimizer with the same configuration before loading
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
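
After the states are restored, training can continue from the next epoch. A minimal continuation sketch, again using the hypothetical train_one_epoch and num_epochs from above:

model.train()  # switch back to training mode before resuming

for epoch in range(checkpoint['epoch'] + 1, num_epochs):
    loss = train_one_epoch(model, optimizer)  # hypothetical training step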
