PyTorch in-place operations and autograd

# Workaround based on:
# https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4

import torch


class _AssignInAVerySafeManner(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scratch, value, index):
        ctx.index = index
        # writing through .data skips the version-counter bump, so autograd's
        # in-place correctness check never notices this assignment
        scratch.data[index] = value
        return scratch

    @staticmethod
    def backward(ctx, grad_scratch):
        # route the written slot's gradient to `value`, pass the rest through to
        # `scratch`, and return no gradient for `index` (the pass-through at
        # ctx.index is harmless here: the pre-write slot feeds nothing later)
        return grad_scratch, grad_scratch[ctx.index], None
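
# Quick sanity check of the helper (a sketch, not in the original example):
# forward() writes `value` into `scratch[index]` and backward() routes that
# slot's gradient back to `value`.
_scratch = torch.zeros(2)
_value = torch.tensor(5.0, requires_grad=True)
_out = _AssignInAVerySafeManner.apply(_scratch, _value, 1)
_out.sum().backward()
assert _value.grad == 1.0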

N = 3

# create three identical inputs that require gradients
x1 = torch.rand(N)
x2 = x1.clone()
x3 = x1.clone()
x1.requires_grad = True
x2.requires_grad = True
x3.requires_grad = True

# approach 1: compute in place in one tensor, cloning before each write
y1 = torch.empty(N)
y1[0] = x1[0]
for i in range(1, N):
    tmp = (y1[:i] ** 2).sum() + x1[i]  # earlier entries of y1 are needed for the gradient
    y1 = y1.clone()  # make a copy to avoid the in-place error
    y1[i] = tmp  # writing into y1 in place would break the backward pass without the clone
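
# What the clone() above avoids (a minimal sketch, not part of the original
# example): writing into a tensor that autograd already saved for backward
# trips the version-counter check when backward() runs.
_a = torch.rand(N, requires_grad=True)
_b = _a.exp()  # exp() saves its output for the backward pass
_b[0] = 0.0  # in-place write bumps _b's version counter
try:
    _b.sum().backward()
except RuntimeError:
    pass  # "... has been modified by an inplace operation"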

# approach 2: keep every entry as a separate tensor and stack at the end
y2 = [None] * N  # equivalent of an empty tensor
y2[0] = x2[0]
for i in range(1, N):
    y2[i] = (torch.stack(y2[:i]) ** 2).sum() + x2[i]  # assigning into the Python list is not a tensor in-place op
y2 = torch.stack(y2)

# approach 3: compute in place in one tensor via the custom autograd Function
y3 = torch.empty(N)
y3[0] = x3[0]
for i in range(1, N):
    tmp = (y3[:i] ** 2).sum() + x3[i]
    y3 = _AssignInAVerySafeManner.apply(y3, tmp, i)
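
# Optional numerical verification (a sketch, not in the original): gradcheck
# compares the handwritten backward against finite differences and needs
# double-precision inputs; `_fill` is a hypothetical helper for this check.
def _fill(x):
    y = torch.zeros(N, dtype=torch.double)
    y = _AssignInAVerySafeManner.apply(y, x[0], 0)
    for i in range(1, N):
        y = _AssignInAVerySafeManner.apply(y, (y[:i] ** 2).sum() + x[i], i)
    return y

assert torch.autograd.gradcheck(
    _fill, (torch.rand(N, dtype=torch.double, requires_grad=True),)
)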

# backpropagate through all three versions (without the clone, the y1 version would raise an in-place error here)
y1.sum().backward()
y2.sum().backward()
y3.sum().backward()

# make sure outputs and gradients match across all three approaches
assert torch.equal(y1, y2)
assert torch.equal(y2, y3)
assert torch.equal(x1.grad, x2.grad)
assert torch.equal(x2.grad, x3.grad)
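
# Hand-derived check for N == 3 (a sketch, not in the original): with
# y[0] = x[0], y[1] = y[0]**2 + x[1] and y[2] = y[0]**2 + y[1]**2 + x[2],
# the gradient of y.sum() w.r.t. x is [1 + 4*x[0] + 4*x[0]*y[1], 1 + 2*y[1], 1].
if N == 3:
    expected_grad = torch.stack([
        1 + 4 * x1[0] + 4 * x1[0] * y1[1],
        1 + 2 * y1[1],
        torch.tensor(1.0),
    ])
    assert torch.allclose(x1.grad, expected_grad)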
