Resolution module-attribute

Resolution = Literal[
    4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192
]

__all__ module-attribute

__all__ = [
    "Discriminator",
    "Generator",
    "Resolution",
    "default_channels",
    "Blur",
    "EqualConv2d",
    "EqualLeakyReLU",
    "EqualLinear",
    "d_loss",
    "d_reg_loss",
    "g_loss",
    "g_reg_loss",
]

__author__ module-attribute

__author__ = 'Peter Yuen'

__email__ module-attribute

__email__ = 'ppeetteerrsx@gmail.com'

__version__ module-attribute

__version__ = '0.0.0'

default_channels module-attribute

default_channels: Dict[Resolution, int] = {
    4: 512,
    8: 512,
    16: 512,
    32: 512,
    64: 512,
    128: 256,
    256: 128,
    512: 64,
    1024: 32,
}
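
For illustration, a minimal sketch of a custom channel map (the halved widths are an arbitrary example, not a recommended configuration):

from stylegan2_torch import default_channels

# Sketch: a slimmer channel map; any Dict[Resolution, int] covering
# resolutions 4 up to the model's target resolution can be passed to
# Generator / Discriminator via their `channels` argument.
slim_channels = {res: ch // 2 for res, ch in default_channels.items()}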

Blur

Blur(blur_kernel: List[int], factor: int, kernel_size: int)

Bases: nn.Module

Apply blurring FIR filter (before / after) a (downsampling / upsampling) op.

Case 1: Upsample (factor > 0)

Applied after a transpose convolution of stride U and kernel size K.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input | Tensor | (N, C, (H - 1) * U + K - 1 + 1, (W - 1) * U + K - 1 + 1) | required |
| blur_kernel | Tensor | FIR filter | required |
| factor | int | U. Defaults to 2. | required |
| kernel_size | int | K. Defaults to 3. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | (N, C, H * U, W * U) |

Case 2: Downsample (factor < 0)

Applied before a convolution of stride U and kernel size K.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input | Tensor | (N, C, H, W) | required |
| blur_kernel | Tensor | FIR filter | required |
| factor | int | U. Defaults to 2. | required |
| kernel_size | int | K. Defaults to 3. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | (N, C, H - (U + 1) + K - 1, W - (U + 1) + K - 1) |
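
A minimal usage sketch for the upsampling case (shapes follow the tables above; the 9x9 input corresponds to H = W = 4 going through a stride-2 transpose convolution with kernel size 3):

import torch
from stylegan2_torch import Blur

# Blur applied after a stride-2 transpose convolution (factor=2, kernel_size=3),
# using the default [1, 3, 3, 1] FIR kernel.
blur = Blur([1, 3, 3, 1], factor=2, kernel_size=3)
x = torch.randn(1, 512, 9, 9)  # (N, C, (H - 1) * U + K - 1 + 1, ...) with H = W = 4
y = blur(x)                    # (1, 512, 8, 8), i.e. (N, C, H * U, W * U)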

Source code in stylegan2_torch/equalized_lr.py
def __init__(self, blur_kernel: List[int], factor: int, kernel_size: int):
    """
    Apply blurring FIR filter (before / after) a (downsampling / upsampling) op

    Case 1: Upsample (factor > 0)
        Applied after a transpose convolution of stride U and kernel size K

    Args:
        input (Tensor): (N, C, (H - 1) * U + K - 1 + 1, (W - 1) * U + K - 1 + 1)
        blur_kernel (Tensor): FIR filter
        factor (int, optional): U. Defaults to 2.
        kernel_size (int, optional): K. Defaults to 3.

    Returns:
        Tensor: (N, C, H * U, W * U)


    Case 2: Downsample (factor < 0)
        Applied before a convolution of stride U and kernel size K

    Args:
        input (Tensor): (N, C, H, W)
        blur_kernel (Tensor): FIR filter
        factor (int, optional): U. Defaults to 2.
        kernel_size (int, optional): K. Defaults to 3.

    Returns:
        Tensor: (N, C, H - (U + 1) + K - 1, W - (U + 1) + K - 1)
    """
    super().__init__()

    if factor > 0:
        p = (len(blur_kernel) - factor) - (kernel_size - 1)
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2 + 1
    else:
        p = (len(blur_kernel) - abs(factor)) + (kernel_size - 1)
        pad0 = (p + 1) // 2
        pad1 = p // 2

    # Factor to compensate for averaging with zeros if upsampling
    self.kernel: Tensor
    self.register_buffer(
        "kernel", make_kernel(blur_kernel, factor if factor > 0 else 1)
    )
    self.pad = (pad0, pad1)

__call__ class-attribute

__call__ = proxy(forward)

kernel instance-attribute

kernel: Tensor

pad instance-attribute

pad = (pad0, pad1)

forward

forward(input: Tensor) -> Tensor
Source code in stylegan2_torch/equalized_lr.py
def forward(self, input: Tensor) -> Tensor:
    return upfirdn2d(input, self.kernel, pad=self.pad)

Discriminator

Discriminator(
    resolution: Resolution,
    channels: Dict[Resolution, int] = default_channels,
    blur_kernel: List[int] = [1, 3, 3, 1],
)

Bases: nn.Module

Discriminator module

Source code in stylegan2_torch/discriminator/__init__.py
def __init__(
    self,
    resolution: Resolution,
    channels: Dict[Resolution, int] = default_channels,
    blur_kernel: List[int] = [1, 3, 3, 1],
):
    super().__init__()

    # FromRGB followed by ResBlock
    self.n_layers = int(math.log(resolution, 2))

    self.blocks = nn.Sequential(
        ConvBlock(1, channels[resolution], 1),
        *[
            ResBlock(channels[2**i], channels[2 ** (i - 1)], blur_kernel)
            for i in range(self.n_layers, 2, -1)
        ],
    )

    # Minibatch std settings
    self.stddev_group = 4
    self.stddev_feat = 1

    # Final layers
    self.final_conv = ConvBlock(channels[4] + 1, channels[4], 3)
    self.final_relu = EqualLeakyReLU(channels[4] * 4 * 4, channels[4])
    self.final_linear = EqualLinear(channels[4], 1)

__call__ class-attribute

__call__ = proxy(forward)

blocks instance-attribute

blocks = nn.Sequential(
    ConvBlock(1, channels[resolution], 1),
    *[
        ResBlock(
            channels[2**i],
            channels[2 ** (i - 1)],
            blur_kernel,
        )
        for i in range(self.n_layers, 2, -1)
    ],
)

final_conv instance-attribute

final_conv = ConvBlock(channels[4] + 1, channels[4], 3)

final_linear instance-attribute

final_linear = EqualLinear(channels[4], 1)

final_relu instance-attribute

final_relu = EqualLeakyReLU(
    channels[4] * 4 * 4, channels[4]
)

n_layers instance-attribute

n_layers = int(math.log(resolution, 2))

stddev_feat instance-attribute

stddev_feat = 1

stddev_group instance-attribute

stddev_group = 4

forward

forward(input: Tensor, *, return_features: bool = False)
Source code in stylegan2_torch/discriminator/__init__.py
def forward(self, input: Tensor, *, return_features: bool = False):
    # Downsampling blocks
    out: Tensor = self.blocks(input)

    # Minibatch stddev layer in Progressive GAN https://www.youtube.com/watch?v=V1qQXb9KcDY
    # Purpose is to provide variational information to the discriminator to prevent mode collapse
    # Other layers do not cross sample boundaries
    batch, channel, height, width = out.shape
    n_groups = min(batch, self.stddev_group)
    stddev = out.view(
        n_groups, -1, self.stddev_feat, channel // self.stddev_feat, height, width
    )
    stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
    stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
    stddev = stddev.repeat(n_groups, 1, height, width)
    out = torch.cat([out, stddev], 1)

    # Final layers
    out = self.final_conv(out)
    features = self.final_relu(out.view(batch, -1))
    out = self.final_linear(features)

    if return_features:
        return out, features
    else:
        return out
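
A minimal usage sketch (single-channel images are assumed here, matching the ConvBlock(1, ...) stem in __init__ above):

import torch
from stylegan2_torch import Discriminator

disc = Discriminator(256)
imgs = torch.randn(8, 1, 256, 256)                # (N, 1, H, W) at the model resolution
scores = disc(imgs)                               # (8, 1) realness scores
scores, feats = disc(imgs, return_features=True)  # feats: (8, channels[4])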

EqualConv2d

EqualConv2d(
    in_channel: int,
    out_channel: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    bias: bool = True,
)

Bases: nn.Module

Conv2d with equalized learning rate

Source code in stylegan2_torch/equalized_lr.py
def __init__(
    self,
    in_channel: int,
    out_channel: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    bias: bool = True,
):
    super().__init__()

    # Equalized Learning Rate
    self.weight = Parameter(
        torch.randn(out_channel, in_channel, kernel_size, kernel_size)
    )
    # std = gain / sqrt(fan_in)
    self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
    self.stride = stride
    self.padding = padding
    self.bias = Parameter(torch.zeros(out_channel)) if bias else None

__call__ class-attribute

__call__ = proxy(forward)

bias instance-attribute

bias = Parameter(torch.zeros(out_channel)) if bias else None

padding instance-attribute

padding = padding

scale instance-attribute

scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

stride instance-attribute

stride = stride

weight instance-attribute

weight = Parameter(
    torch.randn(
        out_channel, in_channel, kernel_size, kernel_size
    )
)

__repr__

__repr__() -> str
Source code in stylegan2_torch/equalized_lr.py
def __repr__(self) -> str:
    return (
        f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
        f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
    )

forward

forward(input: Tensor) -> Tensor
Source code in stylegan2_torch/equalized_lr.py
def forward(self, input: Tensor) -> Tensor:
    return conv2d(
        input=input,
        weight=self.weight * self.scale,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
    )
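
A short sketch of the equalized learning rate idea: weights are stored as N(0, 1) draws and rescaled by scale = 1 / sqrt(fan_in) on every forward pass, so the effective weights match He initialization while all layers see the same optimizer step size:

import torch
from stylegan2_torch import EqualConv2d

conv = EqualConv2d(64, 128, kernel_size=3, padding=1)
x = torch.randn(2, 64, 16, 16)
y = conv(x)  # (2, 128, 16, 16); effective weight std ~= 1 / sqrt(64 * 3 * 3)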

EqualLeakyReLU

EqualLeakyReLU(
    in_dim: int, out_dim: int, lr_mult: float = 1
)

Bases: nn.Module

Leaky ReLU with equalized learning rate

Source code in stylegan2_torch/equalized_lr.py
def __init__(self, in_dim: int, out_dim: int, lr_mult: float = 1):
    super().__init__()

    # Equalized Learning Rate
    self.weight = Parameter(torch.randn(out_dim, in_dim).div_(lr_mult))

    self.bias = Parameter(torch.zeros(out_dim))

    self.scale = (1 / math.sqrt(in_dim)) * lr_mult

    self.lr_mult = lr_mult

__call__ class-attribute

__call__ = proxy(forward)

bias instance-attribute

bias = Parameter(torch.zeros(out_dim))

lr_mult instance-attribute

lr_mult = lr_mult

scale instance-attribute

scale = 1 / math.sqrt(in_dim) * lr_mult

weight instance-attribute

weight = Parameter(
    torch.randn(out_dim, in_dim).div_(lr_mult)
)

__repr__

__repr__() -> str
Source code in stylegan2_torch/equalized_lr.py
def __repr__(self) -> str:
    return (
        f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
    )

forward

forward(input: Tensor) -> Tensor
Source code in stylegan2_torch/equalized_lr.py
def forward(self, input: Tensor) -> Tensor:
    out = F.linear(input, self.weight * self.scale)
    return fused_leaky_relu(out, self.bias * self.lr_mult)
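
As a sketch, the fused op is expected to behave like a plain linear layer followed by a biased leaky ReLU with a sqrt(2) gain; the 0.2 slope and sqrt(2) gain are assumptions based on the reference fused_leaky_relu implementation, not guaranteed by this API:

import torch
import torch.nn.functional as F
from stylegan2_torch import EqualLeakyReLU

layer = EqualLeakyReLU(16, 8)
x = torch.randn(4, 16)
out = layer(x)  # (4, 8)
# Assumed reference behaviour of the fused kernel:
ref = F.leaky_relu(
    F.linear(x, layer.weight * layer.scale, layer.bias * layer.lr_mult), 0.2
) * (2 ** 0.5)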

EqualLinear

EqualLinear(
    in_dim: int,
    out_dim: int,
    bias_init: int = 0,
    lr_mult: float = 1,
)

Bases: nn.Module

Linear with equalized learning rate

Source code in stylegan2_torch/equalized_lr.py
def __init__(
    self,
    in_dim: int,
    out_dim: int,
    bias_init: int = 0,
    lr_mult: float = 1,
):
    super().__init__()

    # Equalized Learning Rate
    self.weight = Parameter(torch.randn(out_dim, in_dim).div_(lr_mult))

    self.bias = Parameter(torch.zeros(out_dim).fill_(bias_init))

    self.scale = (1 / math.sqrt(in_dim)) * lr_mult

    self.lr_mult = lr_mult

__call__ class-attribute

__call__ = proxy(forward)

bias instance-attribute

bias = Parameter(torch.zeros(out_dim).fill_(bias_init))

lr_mult instance-attribute

lr_mult = lr_mult

scale instance-attribute

scale = 1 / math.sqrt(in_dim) * lr_mult

weight instance-attribute

weight = Parameter(
    torch.randn(out_dim, in_dim).div_(lr_mult)
)

__repr__

__repr__() -> str
Source code in stylegan2_torch/equalized_lr.py
def __repr__(self) -> str:
    return (
        f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
    )

forward

forward(input: Tensor) -> Tensor
Source code in stylegan2_torch/equalized_lr.py
def forward(self, input: Tensor) -> Tensor:
    return F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mult)
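
A brief sketch of lr_mult: the stored weight is divided by lr_mult while the runtime scale is multiplied by it, leaving the initial output distribution unchanged while shrinking the layer's effective learning rate (the mapping network uses this via Generator's lr_mlp_mult = 0.01):

import torch
from stylegan2_torch import EqualLinear

layer = EqualLinear(512, 512, lr_mult=0.01)
y = layer(torch.randn(4, 512))  # (4, 512); same init statistics as lr_mult=1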

Generator

Generator(
    resolution: Resolution,
    latent_dim: int = 512,
    n_mlp: int = 8,
    lr_mlp_mult: float = 0.01,
    channels: Dict[Resolution, int] = default_channels,
    blur_kernel: List[int] = [1, 3, 3, 1],
)

Bases: nn.Module

Generator module

Source code in stylegan2_torch/generator/__init__.py
def __init__(
    self,
    resolution: Resolution,
    latent_dim: int = 512,
    n_mlp: int = 8,
    lr_mlp_mult: float = 0.01,
    channels: Dict[Resolution, int] = default_channels,
    blur_kernel: List[int] = [1, 3, 3, 1],
):
    super().__init__()

    self.latent_dim = latent_dim

    # Create mapping network
    self.mapping = MappingNetwork(latent_dim, n_mlp, lr_mlp_mult)

    # Create constant input
    self.input = ConstantInput(channels[4], 4)

    # Create Conv, UpConv and ToRGB Blocks
    self.convs = nn.ModuleList()
    self.up_convs = nn.ModuleList()
    self.to_rgbs = nn.ModuleList()

    self.n_layers = int(math.log(resolution, 2))
    self.n_w_plus = self.n_layers * 2 - 2

    for layer_idx in range(2, self.n_layers + 1):
        # Upsample condition
        upsample = layer_idx > 2

        # Calculate image size and channels at the layer
        prev_layer_size = 2 ** (layer_idx - 1)
        layer_size: Resolution = 2 ** layer_idx
        layer_channel = channels[layer_size]

        # Upsampling Conv Block
        if upsample:
            self.up_convs.append(
                UpModConvBlock(
                    channels[prev_layer_size],
                    layer_channel,
                    3,
                    latent_dim,
                    2,
                    blur_kernel,
                )
            )

        # Normal Conv Block
        self.convs.append(ModConvBlock(layer_channel, layer_channel, 3, latent_dim))

        # ToRGB Block
        self.to_rgbs.append(
            ToRGB(
                layer_channel,
                latent_dim,
                2 if upsample else 1,
                blur_kernel,
            )
        )

__call__ class-attribute

__call__ = proxy(forward)

convs instance-attribute

convs = nn.ModuleList()

input instance-attribute

input = ConstantInput(channels[4], 4)

latent_dim instance-attribute

latent_dim = latent_dim

mapping instance-attribute

mapping = MappingNetwork(latent_dim, n_mlp, lr_mlp_mult)

n_layers instance-attribute

n_layers = int(math.log(resolution, 2))

n_w_plus instance-attribute

n_w_plus = self.n_layers * 2 - 2

to_rgbs instance-attribute

to_rgbs = nn.ModuleList()

up_convs instance-attribute

up_convs = nn.ModuleList()

forward

forward(
    input: Sequence[Tensor],
    *,
    return_latents: bool = False,
    input_type: Literal["z", "w", "w_plus"] = "z",
    trunc_option: Optional[Tuple[float, Tensor]] = None,
    mix_index: Optional[int] = None,
    noises: Optional[List[Optional[Tensor]]] = None
)
Source code in stylegan2_torch/generator/__init__.py
def forward(
    self,
    # Input tensors (N, latent_dim)
    input: Sequence[Tensor],
    *,
    # Return latents
    return_latents: bool = False,
    # Type of input tensor
    input_type: Literal["z", "w", "w_plus"] = "z",
    # Truncation options
    trunc_option: Optional[Tuple[float, Tensor]] = None,
    # Mixing regularization options
    mix_index: Optional[int] = None,
    # Noise vectors
    noises: Optional[List[Optional[Tensor]]] = None,
):
    # Get w vectors (can have 2 w vectors for mixing regularization)
    ws: List[Tensor]

    if input_type == "z":
        ws = [self.mapping(z) for z in input]
    else:
        ws = list(input)

    # Perform truncation
    if trunc_option:
        trunc_coeff, trunc_tensor = trunc_option
        ws = [trunc_tensor + trunc_coeff * (w - trunc_tensor) for w in ws]

    # Mixing regularization (why add dimension 1 not 0 lol)
    w_plus: Tensor
    if len(ws) == 1:
        # No mixing regularization
        mix_index = self.n_w_plus

        if input_type == "w_plus":
            w_plus = ws[0]
        else:
            w_plus = ws[0].unsqueeze(1).repeat(1, mix_index, 1)

    else:
        mix_index = mix_index if mix_index else random.randint(1, self.n_w_plus - 1)

        w_plus1 = ws[0].unsqueeze(1).repeat(1, mix_index, 1)
        w_plus2 = ws[1].unsqueeze(1).repeat(1, self.n_w_plus - mix_index, 1)

        w_plus = torch.cat([w_plus1, w_plus2], 1)
    # Get noise
    noises_: List[Optional[Tensor]] = (
        noises if noises else [None] * (self.n_w_plus - 1)
    )

    # Constant input
    out = self.input(w_plus)

    # References for this weird indexing:
    # https://github.com/NVlabs/stylegan2-ada-pytorch/issues/50
    # https://github.com/rosinality/stylegan2-pytorch/issues/278
    img = None
    for i in range(self.n_layers - 1):
        if i > 0:
            out = self.up_convs[i - 1](
                out, w_plus[:, i * 2 - 1], noises_[i * 2 - 1]
            )

        out = self.convs[i](out, w_plus[:, i * 2], noises_[i * 2])
        img = self.to_rgbs[i](out, w_plus[:, i * 2 + 1], img)

    if return_latents:
        return img, w_plus
    else:
        return img

mean_latent

mean_latent(n_sample: int, device: str) -> Tensor
Source code in stylegan2_torch/generator/__init__.py
def mean_latent(self, n_sample: int, device: str) -> Tensor:
    mean_latent = self.mapping(
        torch.randn(n_sample, self.latent_dim, device=device)
    ).mean(0, keepdim=True)
    mean_latent.detach_()
    return mean_latent
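
A minimal sampling sketch combining truncation and style mixing (the 0.7 truncation coefficient and mix_index of 6 are arbitrary illustrative choices):

import torch
from stylegan2_torch import Generator

gen = Generator(256)
mean_w = gen.mean_latent(4096, "cpu")  # truncation center

z1 = torch.randn(4, 512)
z2 = torch.randn(4, 512)
imgs = gen([z1, z2], trunc_option=(0.7, mean_w), mix_index=6)  # 256x256 images
imgs, w_plus = gen([z1], return_latents=True)  # w_plus: (4, n_w_plus, 512)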

__docs

__docs()

Build gh-pages documentation branch.

Source code in stylegan2_torch/__init__.py
def __docs():  # pragma: no cover
    """
    Build gh-pages documentation branch.
    """
    shell(
        "cp README.md docs/index.md && \
            mkdocs gh-deploy --force"
    )

__serve

__serve()

Serve local documentation.

Source code in stylegan2_torch/__init__.py
def __serve():  # pragma: no cover
    """
    Serve local documentation.
    """
    print("serving")
    shell(
        "cp README.md docs/index.md && \
            mkdocs serve"
    )

__test

__test()

Runs pytest locally and keeps only coverage.xml for GitHub Actions to upload to Codecov.

Source code in stylegan2_torch/__init__.py
def __test():  # pragma: no cover
    """
    Runs pytest locally and keeps only `coverage.xml` for GitHub Actions to upload to Codecov.
    """
    shell(
        "pytest --cov=stylegan2_torch --cov-report xml --cov-report term-missing tests \
            && rm -rf .pytest_cache && rm .coverage"
    )

d_loss

d_loss(real_pred: Tensor, fake_pred: Tensor) -> Tensor

Calculates the discriminator loss (equivalent to the adversarial loss in the original GAN paper).

loss = softplus(-f(x_real)) + softplus(f(x_fake))

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| real_pred | Tensor | Predicted scores for real images | required |
| fake_pred | Tensor | Predicted scores for fake images | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Discriminator loss |

Source code in stylegan2_torch/loss.py
def d_loss(
    real_pred: Tensor,
    fake_pred: Tensor,
) -> Tensor:
    """
    Calculates the discriminator loss.
    (equivalent to adversarial loss in original GAN paper).

    loss = softplus(-f(x_real)) + softplus(f(x_fake))

    Args:
        real_pred (Tensor): Predicted scores for real images
        fake_pred (Tensor): Predicted scores for fake images

    Returns:
        Tensor: Discriminator loss
    """

    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)

    return real_loss.mean() + fake_loss.mean()
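
A tiny numeric check (values are illustrative):

import torch
from stylegan2_torch import d_loss

real_pred = torch.tensor([2.0, 1.5])    # confident "real" scores
fake_pred = torch.tensor([-1.0, -0.5])  # confident "fake" scores
print(d_loss(real_pred, fake_pred))     # ~0.56, small because both are correct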

d_reg_loss

d_reg_loss(real_pred: Tensor, real_img: Tensor) -> Tensor
Calculates the discriminator R_1 loss.

Note

The loss function was first proposed in https://arxiv.org/pdf/1801.04406.pdf. This regularization term penalizes the discriminator for producing gradients orthogonal to the true data manifold (i.e. the expected gradient w.r.t. the real image distribution should be zero). This means that:

  1. The discriminator score cannot improve once the generator reaches the true data distribution (because the discriminator gives the same expected score when inputs come from the sample distribution, based on this regularization term)
  2. Near the Nash equilibrium, the discriminator is encouraged to minimize the gradient magnitude (because the adversarial loss cannot improve, see 1)

Points 1 and 2 are somewhat chicken-and-egg in nature, but the idea is to help convergence to the Nash equilibrium.

Source code in stylegan2_torch/loss.py
def d_reg_loss(real_pred: Tensor, real_img: Tensor) -> Tensor:
    """
    Calculates the discriminator R_1 loss.

    Note:
        The loss function was first proposed in [https://arxiv.org/pdf/1801.04406.pdf](https://arxiv.org/pdf/1801.04406.pdf).
        This regularization term penalizes the discriminator for producing gradients orthogonal to the true data manifold
        (i.e. the expected gradient w.r.t. the real image distribution should be zero). This means that:

        1. Discriminator score cannot improve once generator reaches true data distribution (because discriminator gives same expected score if inputs are from sample distribution, based on this regularization term)
        2. Near Nash equilibrium, discriminator is encouraged to minimize the gradient magnitude (because adversarial loss cannot improve, see 1)

        Points 1 and 2 are sort of chicken-and-egg in nature but the idea is to help converge to the Nash equilibrium.
    """

    # Gradients w.r.t. convolution weights are not needed since only gradients w.r.t. input images are propagated
    with no_weight_grad():
        # create_graph = true because we still need to use this gradient to perform backpropagation
        # real_pred.sum() is needed to obtain a scalar, but does not affect gradients (since each sample independently contributes to output)
        (grad_real,) = autograd.grad(
            outputs=real_pred.sum(), inputs=real_img, create_graph=True
        )
    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

    return grad_penalty
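
A minimal R1-penalty sketch; real_img must require grad so the scores can be differentiated w.r.t. the input pixels (the gamma = 10 weighting is an assumption for illustration, not part of this function):

import torch
from stylegan2_torch import Discriminator, d_reg_loss

disc = Discriminator(64)
real_img = torch.randn(2, 1, 64, 64, requires_grad=True)
real_pred = disc(real_img)
r1 = d_reg_loss(real_pred, real_img)  # scalar gradient penalty
(10 / 2 * r1).backward()              # assumed gamma / 2 weighting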

g_loss

g_loss(fake_pred: Tensor) -> Tensor

Calculates the generator loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fake_pred | Tensor | Predicted scores for fake images | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Generator loss |

Source code in stylegan2_torch/loss.py
def g_loss(fake_pred: Tensor) -> Tensor:
    """
    Calculates the generator loss.

    Args:
        fake_pred (Tensor): Predicted scores for fake images

    Returns:
        Tensor: Generator loss
    """
    loss = F.softplus(-fake_pred).mean()

    return loss
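
The non-saturating form softplus(-f(G(z))) in a tiny numeric check (values are illustrative):

import torch
from stylegan2_torch import g_loss

print(g_loss(torch.tensor([2.0])))   # ~0.13: discriminator fooled, low loss
print(g_loss(torch.tensor([-2.0])))  # ~2.13: fake detected, high loss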

g_reg_loss

g_reg_loss(
    fake_img: Tensor,
    latents: Tensor,
    mean_path_length: Union[Tensor, Literal[0]],
    decay: float = 0.01,
) -> Tuple[Tensor, Tensor, Tensor]

Calculates the generator path length regularization loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fake_img | Tensor | Generated images (N, C, H, W) | required |
| latents | Tensor | W+ latent vectors (N, P, 512), P = number of style vectors | required |
| mean_path_length | Union[Tensor, Literal[0]] | Current accumulated mean path length (dynamic a) | required |
| decay | float | Decay in accumulating a. Defaults to 0.01. | 0.01 |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Tensor, Tensor, Tensor] | Path loss, mean path length, path lengths |

Note

This loss function was first introduced in StyleGAN2. The idea is that fixed-size steps in W should result in fixed-magnitude changes in the image.

Key Intuition: minimizing \(\mathbb{E}_{\mathbf{w}, \mathbf{y} \sim \mathcal{N}(0, \mathbf{I})}\left(\lVert \mathbf{J}^T_{\mathbf{w}} \mathbf{y} \rVert_2 - a\right)^2\) is equivalent to scaling \(W+\) equally in each dimension.

Reason:

  1. Do SVD on \(\mathbf{J}^T_{\mathbf{w}} = U \bar{\Sigma} V^T\)
  2. \(U\) and \(V\) are orthogonal and hence irrelevant (orthogonal matrices simply rotate the vector; since \(\mathbf{y}\) is standard normal, it follows the same distribution after rotation)
  3. \(\bar{\Sigma}\) has \(L\) non-zero singular values representing the scaling factor in each of the \(L\) dimensions
  4. The loss is minimized when \(\bar{\Sigma}\) has identical singular values equal to \(\frac{a}{\sqrt{L}}\) (because high-dimensional normal distributions have norm concentrated around \(\sqrt{L}\))

Info

Implementation:

  1. \(a\) is set dynamically as the moving average of the path lengths (akin to searching for the appropriate scaling factor in a non-aggressive manner).
  2. As explained in the paper's Appendix B, the ideal weight for path regularization is \(\gamma_{pl} = \frac{\ln 2}{r^2(\ln r - \ln 2)}\). This is achieved by setting pl_weight; in the code, the loss is first scaled by \(r^2\) (i.e. height * width) in noise, then by n_layers in path_lengths by taking the mean over the n_layers style vectors. The result is equivalent to saying that the ideal pl_weight is 2. See https://github.com/NVlabs/stylegan2/blob/master/training/loss.py.
  3. path_batch_shrink controls the fraction of the batch size used for regularization, reducing its memory footprint, since it is done without freeing the memory of the existing batch.
  4. Identity: \(\mathbf{J}^T_{\mathbf{w}} \mathbf{y} = \nabla_{\mathbf{w}} (g(\mathbf{w}) \cdot \mathbf{y})\)
Source code in stylegan2_torch/loss.py
def g_reg_loss(
    fake_img: Tensor,
    latents: Tensor,
    mean_path_length: Union[Tensor, Literal[0]],
    decay: float = 0.01,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Calculates Generator path length regularization loss.

    Args:
        fake_img (Tensor): Generated images (N, C, H, W)
        latents (Tensor): W+ latent vectors (N, P, 512), P = number of style vectors
        mean_path_length (Union[Tensor, Literal[0]]): Current accumulated mean path length (dynamic `a`)
        decay (float, optional): Decay in accumulating `a`. Defaults to 0.01.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: Path loss, mean path, path length

    Note:
        This loss function was first introduced in StyleGAN2.
        The idea is that fixed-size steps in W result in fixed-magnitude changes in the image.

        **Key Intuition**: minimizing $\mathbb{E}_{\mathbf{w},\mathbf{y} \sim N(0,1)}(||\mathbf{J^T_{\mathbf{w}}\mathbf{y}}||_2 - a)^2$ is equivalent to scaling $W+$ equally in each dimension.

        Reason:

        1. Do SVD on $\mathbf{J^T_{\mathbf{w}}} = U \\bar{\Sigma} V^T$
        2. $U$ and $V$ are orthogonal and hence irrelevant (orthogonal matrices simply rotate the vector; since $\mathbf{y}$ is N(0,1), it follows the same distribution after rotation)
        3. $\\bar{\Sigma}$ has $L$ non-zero singular values representing scaling factor in $L$ dimensions
        4. Loss is minimized when $\\bar{\Sigma}$ has identical singular values equal $\\frac{a}{\sqrt{L}}$ (because high-dimensional normal distributions have norm centered around $\sqrt{L}$)

    Info:
        Implementation:

        1. $a$ is set dynamically using the moving average of the path_lengths (sort of like searching for the appropriate scaling factor in a non-aggressive manner).
        2. As explained in the paper's Appendix B, the ideal weight for path regularization is $\gamma_{pl} = \\frac{\ln 2}{r^2(\ln r - \ln 2)}$. This is achieved by setting `pl_weight`; in the code, the loss is first scaled by $r^2$ (i.e. height * width) in `noise`, then by `n_layers` in `path_lengths` by taking the mean over the `n_layers` style vectors. The result is equivalent to saying that the ideal `pl_weight` is 2. See [here](https://github.com/NVlabs/stylegan2/blob/master/training/loss.py).
        3. `path_batch_shrink` controls the fraction of the batch size to use, reducing the memory footprint of regularization, since it is done without freeing the memory of the existing batch.
        4. Identity $\mathbf{J^T_{\mathbf{w}}} \mathbf{y} = \\nabla (g(\mathbf{w}) \mathbf{y})$

    """

    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3]
    )

    (grad,) = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths
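
A minimal path-length regularization sketch (the pl_weight of 2 follows the Info note above; batch shrinking, as path_batch_shrink would do, is omitted for brevity):

import torch
from stylegan2_torch import Generator, g_reg_loss

gen = Generator(64)
z = torch.randn(2, 512)
fake_img, w_plus = gen([z], return_latents=True)
path_loss, mean_path, path_lengths = g_reg_loss(
    fake_img, w_plus, mean_path_length=0
)
(2 * path_loss).backward()  # assumed pl_weight of 2, per the Info note above
# mean_path is carried into the next regularization step as mean_path_length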