Part 19: PyTorch, Custom Gradient Functions

by digitaltech2.com

In PyTorch, you can define custom gradient functions by subclassing torch.autograd.Function. This lets you implement the forward and backward passes yourself, giving you full control over how gradients are computed.

Defining a Custom Function

To create a custom function, you define two static methods: forward and backward. The forward method performs the forward computation, and the backward method computes the gradients. Both receive a context object, ctx, which carries information from the forward pass to the backward pass, and the function is invoked through its apply method.
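As a minimal sketch of this anatomy (not part of the original example), an identity function would simply return its input in forward and pass the incoming gradient through unchanged in backward:

import torch
from torch.autograd import Function

class MyIdentity(Function):
    @staticmethod
    def forward(ctx, input):
        # Nothing needs to be saved; the output is just a copy of the input
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of the identity flows through unchanged
        return grad_output

x = torch.tensor([1.0, 2.0], requires_grad=True)
MyIdentity.apply(x).sum().backward()
print(x.grad)  # tensor([1., 1.])

The ReLU example below applies the same pattern to a real operation.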

  • Custom ReLU Example:
import torch
from torch.autograd import Function

class MyReLU(Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input so the backward pass can see where it was negative
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the saved input and zero the gradient wherever the input was negative
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# Using the custom ReLU
relu = MyReLU.apply
x = torch.tensor([-1.0, 0.0, 1.0, 2.0], requires_grad=True)
y = relu(x)
y.sum().backward()

print("Input:", x)
print("ReLU Output:", y)
print("Gradient of x:", x.grad)

In this example, the custom ReLU clamps negative inputs to zero in the forward pass; in the backward pass it passes the incoming gradient through wherever the input was non-negative and zeroes it elsewhere.
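A convenient way to validate a custom backward is torch.autograd.gradcheck, which compares the analytical gradient against numerical finite differences. The snippet below is a sketch, not part of the original example; it uses double-precision inputs and nudges the values away from 0, where ReLU is not differentiable.

# Sanity-check MyReLU's gradient with gradcheck (requires double precision)
x = torch.randn(8, dtype=torch.double)
x = x + 0.1 * x.sign()  # keep values away from the kink at 0
x.requires_grad_(True)
print(torch.autograd.gradcheck(MyReLU.apply, (x,)))  # True if the gradients match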

Custom Gradient Example with Additional Parameters

You can also create custom functions that take additional parameters. Here’s an example that adds a learnable scalar, scale, to the input tensor and returns a gradient for each argument.

  • Custom Scaled Addition Example:
class ScaledAdd(Function):
    @staticmethod
    def forward(ctx, input, scale):
        # Save both tensors for the backward pass (not strictly needed here,
        # since the gradients below do not depend on them)
        ctx.save_for_backward(input, scale)
        return input + scale

    @staticmethod
    def backward(ctx, grad_output):
        input, scale = ctx.saved_tensors
        # d(input + scale) / d(input) = 1, so the gradient passes straight through
        grad_input = grad_output.clone()
        # scale is broadcast across every element, so its gradient is the sum
        grad_scale = grad_output.sum()
        # backward returns one gradient per input to forward: (input, scale)
        return grad_input, grad_scale

# Using the custom scaled addition
scaled_add = ScaledAdd.apply
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
scale = torch.tensor(0.5, requires_grad=True)
y = scaled_add(x, scale)
y.sum().backward()

print("Input:", x)
print("Scale:", scale)
print("Output:", y)
print("Gradient of x:", x.grad)
print("Gradient of scale:", scale.grad)

In this example, the ScaledAdd function adds the scalar scale to every element of the input tensor; the backward pass returns one gradient for input (grad_output passed through) and one for scale (the sum of grad_output, since scale is broadcast across all elements).
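In practice, a custom Function is often wrapped in an nn.Module so its parameters can be registered and trained like any other layer. The sketch below reuses the ScaledAdd class defined above; the ScaledAddLayer name and the initial value of 0.5 are illustrative choices, not part of the original example.

import torch.nn as nn

class ScaledAddLayer(nn.Module):
    def __init__(self, init_scale=0.5):
        super().__init__()
        # Register scale as a learnable parameter
        self.scale = nn.Parameter(torch.tensor(init_scale))

    def forward(self, x):
        # Route the computation through the custom autograd Function
        return ScaledAdd.apply(x, self.scale)

layer = ScaledAddLayer()
out = layer(torch.tensor([1.0, 2.0, 3.0]))
out.sum().backward()
print("Gradient of the layer's scale:", layer.scale.grad)  # tensor(3.)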
