Example
In [1]:
Copied!
from torch_conv_gradfix import enable, disable, conv2d, conv_transpose2d, no_weight_grad
import torch
transpose = True
N, SIZE = 2, 8
IN_C, OUT_C = 3, 8
KERNEL, STRIDE, PADDING, DILATION, GROUPS = 3, 2, 2, 1, 1
device = "cpu" if not torch.cuda.is_available() else "cuda"
input = torch.randn(
N, IN_C, SIZE, SIZE, dtype=torch.double, requires_grad=True, device=device
)
weight = torch.randn(
OUT_C, IN_C, KERNEL, KERNEL, dtype=torch.double, requires_grad=True, device=device
)
trans_weight = torch.randn(
IN_C, OUT_C, KERNEL, KERNEL, dtype=torch.double, requires_grad=True, device=device
)
bias = torch.randn(OUT_C, dtype=torch.double, requires_grad=True, device=device)
from torch_conv_gradfix import enable, disable, conv2d, conv_transpose2d, no_weight_grad
import torch
transpose = True
N, SIZE = 2, 8
IN_C, OUT_C = 3, 8
KERNEL, STRIDE, PADDING, DILATION, GROUPS = 3, 2, 2, 1, 1
device = "cpu" if not torch.cuda.is_available() else "cuda"
input = torch.randn(
N, IN_C, SIZE, SIZE, dtype=torch.double, requires_grad=True, device=device
)
weight = torch.randn(
OUT_C, IN_C, KERNEL, KERNEL, dtype=torch.double, requires_grad=True, device=device
)
trans_weight = torch.randn(
IN_C, OUT_C, KERNEL, KERNEL, dtype=torch.double, requires_grad=True, device=device
)
bias = torch.randn(OUT_C, dtype=torch.double, requires_grad=True, device=device)
In [2]:
Copied!
# When torch_conv_gradfix is enabled, no_weight_grad() stops the weight gradient calculations
conv_out = conv2d(
input,
weight,
bias,
stride=STRIDE,
padding=PADDING,
dilation=DILATION,
groups=GROUPS,
)
input.grad = weight.grad = bias.grad = None
with no_weight_grad():
conv_out.sum().backward()
assert input.grad is not None and weight.grad is None
# When torch_conv_gradfix is enabled, no_weight_grad() stops the weight gradient calculations
conv_out = conv2d(
input,
weight,
bias,
stride=STRIDE,
padding=PADDING,
dilation=DILATION,
groups=GROUPS,
)
input.grad = weight.grad = bias.grad = None
with no_weight_grad():
conv_out.sum().backward()
assert input.grad is not None and weight.grad is None
--------------------------------------------------------------------------- AssertionError Traceback (most recent call last) Input In [2], in <cell line: 14>() 12 with no_weight_grad(): 13 conv_out.sum().backward() ---> 14 assert input.grad is not None and weight.grad is None AssertionError:
In [3]:
Copied!
# When torch_conv_gradfix is disabled, no_weight_grad() has no effects
conv_trans_out = conv_transpose2d(
input,
trans_weight,
bias,
stride=STRIDE,
padding=PADDING,
dilation=DILATION,
groups=GROUPS,
)
disable()
input.grad = trans_weight.grad = bias.grad = None
with no_weight_grad():
conv_trans_out.sum().backward()
assert input.grad is not None and trans_weight.grad is not None
# When torch_conv_gradfix is disabled, no_weight_grad() has no effects
conv_trans_out = conv_transpose2d(
input,
trans_weight,
bias,
stride=STRIDE,
padding=PADDING,
dilation=DILATION,
groups=GROUPS,
)
disable()
input.grad = trans_weight.grad = bias.grad = None
with no_weight_grad():
conv_trans_out.sum().backward()
assert input.grad is not None and trans_weight.grad is not None