PyTorch inplace
…
…
# https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4
import torch
class _AssignInAVerySafeManner(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scratch, value, index):
        ctx.index = index
        scratch.data[index] = value
        return scratch

    @staticmethod
    def backward(ctx, grad_scratch):
        return grad_scratch, grad_scratch[ctx.index], None
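# Note on the workaround above (explanatory comment, not in the original code):
# assigning through scratch.data bypasses autograd's version counter, so later
# reads of the scratch tensor do not trigger the usual in-place modification
# error; backward() then hands back grad_scratch for the scratch input and the
# indexed entry grad_scratch[ctx.index] as the gradient for `value`.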
N = 3
# create variables that require gradients
x1 = torch.rand(N)
x2 = x1.clone()
x3 = x1.clone()
x1.requires_grad = True
x2.requires_grad = True
x3.requires_grad = True
# compute inplace in one tensor
y1 = torch.empty(N)
y1[0] = x1[0]
for i in range(1, N):
    tmp = (y1[:i] ** 2).sum() + x1[i]  # earlier entries of y1 are needed for the gradient
    y1 = y1.clone()  # make a copy to prevent the in-place error
    y1[i] = tmp  # writing into y1 in place would break gradient computation without the clone
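# For reference: if the clone above is skipped, y1[i] = tmp modifies entries that
# (y1[:i] ** 2).sum() still needs for its backward pass, and y1.sum().backward()
# fails with a RuntimeError along the lines of "one of the variables needed for
# gradient computation has been modified by an inplace operation" (exact wording
# depends on the PyTorch version).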
# split into separate tensors
y2 = [None for i in range(N)]  # equivalent of empty tensor
y2[0] = x2[0]
for i in range(1, N):
    y2[i] = (torch.stack(y2[:i]) ** 2).sum() + x2[i]  # "inplace" write is fine because each entry is a separate tensor
y2 = torch.stack(y2)
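# Design note: keeping the partial results in a Python list and stacking only at
# the end means no autograd-tracked tensor is ever written in place, at the cost
# of a fresh torch.stack allocation on every iteration.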
# compute inplace in one tensor with special function
y3 = torch.empty(N)
y3[0] = x3[0]
for i in range(1, N):
    tmp = (y3[:i] ** 2).sum() + x3[i]
    y3 = _AssignInAVerySafeManner.apply(y3, tmp, i)  # custom Function assigns without tripping the version check
# backward works for all three; without the clone above, y1's backward would raise an in-place error
y1.sum().backward()
y2.sum().backward()
y3.sum().backward()
# make sure all three approaches give the same values and gradients
assert torch.equal(y1, y2)
assert torch.equal(y2, y3)
assert torch.equal(x1.grad, x2.grad)
assert torch.equal(x2.grad, x3.grad)
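# Optional sanity check (not part of the original example; `_cumulative` and
# `x_check` are illustrative names): verify the recurrence's analytic gradients
# numerically with torch.autograd.gradcheck, which expects double precision inputs.
def _cumulative(x):
    out = [x[0]]
    for i in range(1, x.shape[0]):
        out.append((torch.stack(out) ** 2).sum() + x[i])
    return torch.stack(out)

x_check = torch.rand(N, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(_cumulative, (x_check,))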
Total running time of the script: (0 minutes 3.082 seconds)