In PyTorch, you can define custom gradient functions by subclassing `torch.autograd.Function`. This allows you to implement both the forward and backward passes with custom operations, giving you fine-grained control over gradient computation.
Defining a Custom Function
To create a custom function, you need to define two static methods: `forward` and `backward`. The `forward` method performs the forward computation, and the `backward` method receives the gradient of the loss with respect to the output and returns the gradient with respect to each input.
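For instance, a squaring function can be written as follows (a minimal sketch with placeholder names, using the derivative d/dx x² = 2x; the ReLU example below follows the same pattern):

import torch
from torch.autograd import Function

class MySquare(Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input; backward needs it to evaluate 2 * input
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Chain rule: dL/dinput = dL/doutput * 2 * input
        return grad_output * 2 * input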
- Custom ReLU Example:
import torch
from torch.autograd import Function

class MyReLU(Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input so the backward pass can rebuild the sign mask
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Pass gradients through where the input was non-negative; zero them elsewhere
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
# Using the custom ReLU
relu = MyReLU.apply
x = torch.tensor([-1.0, 0.0, 1.0, 2.0], requires_grad=True)
y = relu(x)
y.sum().backward()
print("Input:", x)
print("ReLU Output:", y)
print("Gradient of x:", x.grad)
In this example, the custom ReLU clamps negative values to zero in the forward pass and, in the backward pass, passes the gradient through unchanged wherever the input was non-negative while zeroing it elsewhere.
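You can check a custom backward implementation against numerical finite differences with `torch.autograd.gradcheck`. A minimal sketch (`gradcheck` expects double-precision inputs, and points exactly at 0 are best avoided since ReLU is not differentiable there):

from torch.autograd import gradcheck

x_check = torch.tensor([-1.5, -0.3, 0.7, 2.0], dtype=torch.double, requires_grad=True)
# Raises an error (or returns False) if analytical and numerical gradients disagree
print(gradcheck(MyReLU.apply, (x_check,)))  # True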
Custom Gradient Example with Additional Parameters
You can also create custom functions that take additional parameters. Here's an example, `ScaledAdd`, that adds a learnable scalar offset to an input tensor.
- Custom Scaled Addition Example:
class ScaledAdd(Function):
    @staticmethod
    def forward(ctx, input, scale):
        # Addition's gradients don't depend on the inputs,
        # so nothing needs to be saved for the backward pass
        return input + scale

    @staticmethod
    def backward(ctx, grad_output):
        # The output depends on input element-wise with slope 1
        grad_input = grad_output.clone()
        # scale is broadcast over every element, so its gradient
        # is the sum of the incoming gradient
        grad_scale = grad_output.sum()
        # Return one gradient per forward input, in order
        return grad_input, grad_scale
# Using the custom scaled addition
scaled_add = ScaledAdd.apply
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
scale = torch.tensor(0.5, requires_grad=True)
y = scaled_add(x, scale)
y.sum().backward()
print("Input:", x)
print("Scale:", scale)
print("Output:", y)
print("Gradient of x:", x.grad)
print("Gradient of scale:", scale.grad)
In this example, the `ScaledAdd` function adds a scalar to the input tensor and returns one gradient for each of its two inputs: an identity gradient for `input` and a summed gradient for `scale`.
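Note that `backward` must return one value for every argument that `forward` received. If an argument is not a tensor (for example, a plain Python number), return `None` in its position. A minimal sketch, using a hypothetical `ScaleByConstant` function:

class ScaleByConstant(Function):
    @staticmethod
    def forward(ctx, input, constant):
        # constant is a plain Python number, not a tensor
        ctx.constant = constant
        return input * constant

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument; None for the non-tensor constant
        return grad_output * ctx.constant, None

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = ScaleByConstant.apply(x, 3.0)
y.sum().backward()
print("Gradient of x:", x.grad)  # tensor([3., 3.])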