loss
d_loss
d_loss(real_pred: Tensor, fake_pred: Tensor) -> Tensor
Calculates the discriminator loss (equivalent to the adversarial loss in the original GAN paper).
loss = softplus(-f(real)) + softplus(f(fake))
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `real_pred` | `Tensor` | Predicted scores for real images | required |
| `fake_pred` | `Tensor` | Predicted scores for fake images | required |
Returns:

| Name | Type | Description |
|---|---|---|
| `Tensor` | `Tensor` | Discriminator loss |
Source code in stylegan2_torch/loss.py
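Below is a minimal sketch of this loss, assuming the standard non-saturating formulation with `torch.nn.functional.softplus` and a mean reduction over the batch (the exact reduction used in `stylegan2_torch/loss.py` is an assumption here; `d_loss_sketch` is an illustrative name, not the library function):

```python
import torch.nn.functional as F
from torch import Tensor


def d_loss_sketch(real_pred: Tensor, fake_pred: Tensor) -> Tensor:
    # softplus(-f(real)): penalize low scores on real images
    real_loss = F.softplus(-real_pred)
    # softplus(f(fake)): penalize high scores on fake images
    fake_loss = F.softplus(fake_pred)
    # Batch-mean reduction is an assumption of this sketch
    return real_loss.mean() + fake_loss.mean()
```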
d_reg_loss
d_reg_loss(real_pred: Tensor, real_img: Tensor) -> Tensor
Calculates the discriminator \(R_1\) loss.

Note

The loss function was first proposed in https://arxiv.org/pdf/1801.04406.pdf. This regularization term penalizes the discriminator for producing gradients orthogonal to the true data manifold (i.e. the expected gradient w.r.t. the real image distribution should be zero). This means that:

1. The discriminator score cannot improve once the generator reaches the true data distribution (because, under this regularization term, the discriminator gives the same expected score to inputs drawn from the sample distribution).
2. Near the Nash equilibrium, the discriminator is encouraged to minimize the gradient magnitude (because the adversarial loss cannot improve, see point 1).

Points 1 and 2 are somewhat chicken-and-egg in nature, but the idea is to help convergence to the Nash equilibrium.
Source code in stylegan2_torch/loss.py
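A hedged sketch of the \(R_1\) penalty described above: the squared L2 norm of the gradient of the discriminator scores w.r.t. the real images. The name `d_reg_loss_sketch` and the omission of the usual \(\frac{\gamma}{2}\) weighting (typically applied by the caller) are assumptions, not the library's exact code:

```python
from torch import Tensor
from torch.autograd import grad


def d_reg_loss_sketch(real_pred: Tensor, real_img: Tensor) -> Tensor:
    # real_img must have requires_grad=True before the discriminator forward pass
    (gradients,) = grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    # Squared gradient norm per image, averaged over the batch
    return gradients.pow(2).reshape(gradients.shape[0], -1).sum(1).mean()
```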
g_loss
g_loss(fake_pred: Tensor) -> Tensor
Calculates the generator loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fake_pred` | `Tensor` | Predicted scores for fake images | required |
Returns:

| Name | Type | Description |
|---|---|---|
| `Tensor` | `Tensor` | Generator loss |
Source code in stylegan2_torch/loss.py
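A minimal sketch, assuming the non-saturating generator loss that mirrors `d_loss` (`softplus(-f(fake))` with a batch mean; `g_loss_sketch` is an illustrative name):

```python
import torch.nn.functional as F
from torch import Tensor


def g_loss_sketch(fake_pred: Tensor) -> Tensor:
    # softplus(-f(fake)): the generator wants fake images to score highly
    return F.softplus(-fake_pred).mean()
```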
g_reg_loss
g_reg_loss(
fake_img: Tensor,
latents: Tensor,
mean_path_length: Union[Tensor, Literal[0]],
decay: float = 0.01,
) -> Tuple[Tensor, Tensor, Tensor]
Calculates the generator path length regularization loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fake_img` | `Tensor` | Generated images (N, C, H, W) | required |
| `latents` | `Tensor` | W+ latent vectors (N, P, 512), P = number of style vectors | required |
| `mean_path_length` | `Union[Tensor, Literal[0]]` | Current accumulated mean path length (the dynamic \(a\)) | required |
| `decay` | `float` | Decay factor for accumulating the mean path length | 0.01 |
Returns:

| Type | Description |
|---|---|
| `Tuple[Tensor, Tensor, Tensor]` | Path loss, mean path, path length |
Note
This loss function was first introduced in StyleGAN2. The idea is that fixed-size steps in \(W\) should result in fixed-magnitude changes in the image.
Key Intuition: minimizing \(\mathbb{E}_{\mathbf{w},\mathbf{y}\sim N(0,1)}(\lVert\mathbf{J}^T_{\mathbf{w}}\mathbf{y}\rVert_2 - a)^2\) is equivalent to scaling \(W+\) equally in each dimension.
Reason:
- Do an SVD on \(\mathbf{J}^T_{\mathbf{w}} = U \bar{\Sigma} V^T\)
- \(U\) and \(V\) are orthogonal and hence irrelevant (orthogonal matrices simply rotate a vector, and since \(\mathbf{y}\) is \(N(0,1)\), its distribution is unchanged by rotation)
- \(\bar{\Sigma}\) has \(L\) non-zero singular values representing the scaling factors in \(L\) dimensions
- The loss is minimized when \(\bar{\Sigma}\) has identical singular values equal to \(\frac{a}{\sqrt{L}}\) (because high-dimensional normal distributions have their norm concentrated around \(\sqrt{L}\)); see the numerical check below
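As a quick illustration of the last point (not part of the library), sampling \(\mathbf{y} \sim N(0, I_L)\) and scaling by identical singular values \(\sigma = \frac{a}{\sqrt{L}}\) gives \(\lVert\bar{\Sigma}\mathbf{y}\rVert_2\) concentrated around \(a\), so \((\lVert\bar{\Sigma}\mathbf{y}\rVert_2 - a)^2\) stays near its minimum:

```python
import torch

L, a = 512, 4.0
sigma = a / L ** 0.5             # identical singular values a / sqrt(L)
y = torch.randn(10_000, L)       # y ~ N(0, I_L)
norms = (sigma * y).norm(dim=1)  # ||Sigma y||_2 with Sigma = sigma * I
print(norms.mean())              # approximately a
print(norms.std())               # small spread relative to the mean
```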
Info

Implementation:

- \(a\) is set dynamically using the moving average of the `path_lengths` (sort of like searching for the appropriate scaling factor in a non-aggressive manner).
- As explained in the paper's Appendix B, the ideal weight for path regularization is \(\gamma_{pl} = \frac{\ln 2}{r^2(\ln r - \ln 2)}\). This is achieved by setting `pl_weight`: in the code, the loss is first scaled by \(r^2\) (i.e. height * width) in `noise`, then by `n_layers` in `path_lengths` by taking the mean over the `n_layers` style vectors. The result is equivalent to saying that the ideal `pl_weight` is 2. See here.
- `path_batch_shrink` controls the fraction of the batch size to use, in order to reduce the memory footprint of the regularization (since it is done without freeing the memory of the existing batch).
- Identity: \(\mathbf{J}^T_{\mathbf{w}} \mathbf{y} = \nabla_{\mathbf{w}} (g(\mathbf{w}) \cdot \mathbf{y})\)
Source code in stylegan2_torch/loss.py
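A hedged sketch of the path length regularization described above: perturb the generator output with noise scaled by \(\frac{1}{\sqrt{HW}}\), backpropagate to the \(W+\) latents to obtain \(\mathbf{J}^T_{\mathbf{w}}\mathbf{y}\), average the resulting path length over the style vectors, and regress it toward an exponential moving average that plays the role of the dynamic \(a\). Names are illustrative and details such as `path_batch_shrink` handling are omitted, so this is not necessarily identical to `stylegan2_torch/loss.py`:

```python
import math
from typing import Literal, Tuple, Union

import torch
from torch import Tensor
from torch.autograd import grad


def g_reg_loss_sketch(
    fake_img: Tensor,
    latents: Tensor,
    mean_path_length: Union[Tensor, Literal[0]],
    decay: float = 0.01,
) -> Tuple[Tensor, Tensor, Tensor]:
    # Scale the noise by 1/sqrt(H*W) so the loss is effectively weighted by r^2
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3]
    )
    # J^T_w y = grad_w(g(w) . y), computed via autograd
    (gradients,) = grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    # ||J^T_w y||_2 per sample, averaged over the n_layers style vectors (dim 1)
    path_lengths = torch.sqrt(gradients.pow(2).sum(2).mean(1))
    # Moving average of the path lengths acts as the dynamic target a
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    # (||J^T_w y||_2 - a)^2
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths
```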