Hi everyone 👋
I'm back with another interesting topic in Generative Adversarial Networks (GANs). GANs are a type of machine learning model that can be used to generate realistic images, text, and other data. It's important to understand GANs because they are used in a variety of applications, such as image editing, product design, and natural language processing.
In this article, I will share some of the most useful loss functions for GANs, from basic to advanced projects. First, I will introduce each loss function and then provide code examples in PyTorch. I will also include mathematical formulas and research papers to help you learn more deeply about each loss function.
If you have any questions or problems, please feel free to ask in the comments. Learning new things is always a challenge, but it's also one of the most rewarding experiences.
Just keep reading…
What Is a Loss Function?
Loss functions are the key way to measure how well a model is performing. They mathematically quantify the difference between the real data and the model's predictions, and they guide the optimizer on how much, and in which direction, the model's parameters should be updated. It is important to learn about loss functions because they appear in almost every machine learning research paper. So, don't skip this topic!
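For intuition, here is a minimal sketch of one of the simplest loss functions, mean squared error, computed by hand in PyTorch. The tensors are made-up example values, not from any real model.
import torch

# Made-up "real" targets and model predictions
real = torch.tensor([1.0, 2.0, 3.0])
predicted = torch.tensor([1.1, 1.8, 3.5])

# Mean squared error: the average of the squared differences
mse = torch.mean((real - predicted) ** 2)
print(mse)  # tensor(0.1000) -- smaller means the predictions are closer to the real data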
Min-Max Loss for GAN Training
The min-max loss function is used to train two models that compete against each other. The model that is trying to minimize the loss is called the "min player" and the model that is trying to maximize the loss is called the "max player." This is also known as a min-max game, and it is a common way to train GANs.
Adversarial loss is a specific type of min-max loss that is used in GANs. It is a two-player game where the generator tries to generate realistic data that the discriminator cannot distinguish from real data. The discriminator tries to distinguish between real and generated data.
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_\text{data}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log(1 - D(G(z)))]$$
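To make the formula concrete, here is a minimal sketch of how V(D, G) could be estimated from a batch of discriminator outputs. The function name and the small eps stabilizer are my own additions for illustration; the discriminator tries to maximize this value while the generator tries to minimize it.
import torch

def value_function(d_real, d_fake, eps=1e-8):
    # V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]
    # d_real: discriminator outputs on real samples, d_fake: outputs on generated samples,
    # both assumed to be probabilities in (0, 1); eps avoids log(0)
    return torch.mean(torch.log(d_real + eps)) + torch.mean(torch.log(1.0 - d_fake + eps))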
Adversarial Loss
Adversarial loss is a way to create competition between the generator and discriminator. The generator tries to generate data that looks like real data, and the discriminator tries to predict whether the data is real or fake. Adversarial loss is the main loss function used in GANs, and it is typically implemented using the binary cross-entropy (BCE) loss function. However, other loss functions can also be used, such as Wasserstein loss.
Here is a real-life analogy for adversarial loss.
Imagine you have two friends, Chi and Zen. Chi is a good artist, and Zen is good at detecting fake drawings. Zen helps Chi to become a better artist by challenging him to create drawings that Zen cannot distinguish from real drawings.
In the same way, the discriminator in a GAN helps the generator to become better by challenging it to generate data that the discriminator cannot distinguish from real data.
Please focus on this formula for calculating the discriminator loss and updating the model parameters.
$$\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m} \left[ \log D(x_i) + \log\left(1 - D(G(z_i))\right) \right]$$
Once the discriminator part is complete, you can proceed to the generator loss formula to calculate the loss and update the model parameters.
$$\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^m \log (1 - D(G(z_i)))$$
To learn more, please read the research paper "Generative Adversarial Networks" (Goodfellow et al., 2014).
import torch
import torch.nn as nn
import torch.optim as optim

# Instantiate the generator and discriminator (architectures defined elsewhere)
gen = Generator(...)
disc = Discriminator(...)
g_optim = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optim = optim.Adam(disc.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Adversarial loss function (binary cross-entropy)
def adversarial_loss(outputs, target_labels):
    loss = nn.BCELoss()
    return loss(outputs, target_labels)

# Training loop using binary cross-entropy
for epoch in range(num_epochs):
    for batch_idx, real_data in enumerate(dataloader):
        batch_size = real_data.size(0)

        # Train the discriminator
        d_optim.zero_grad()
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        real_output = disc(real_data)
        real_loss = adversarial_loss(real_output, real_labels)
        fake_input = torch.randn(batch_size, latent_dim)
        fake_data = gen(fake_input)
        fake_output = disc(fake_data.detach())  # Detach so gradients don't flow back into the generator
        fake_loss = adversarial_loss(fake_output, fake_labels)
        dis_loss = real_loss + fake_loss
        dis_loss.backward()
        d_optim.step()

        # Train the generator
        g_optim.zero_grad()
        fake_output = disc(fake_data)
        gen_loss = adversarial_loss(fake_output, real_labels)  # The generator wants its fakes classified as real
        gen_loss.backward()
        g_optim.step()
Kullback-Leibler divergence (KL-divergence) Loss
A GAN is ultimately trying to make the distribution of its generated data match the distribution of the real data, so we need a way to measure the statistical distance between the two. KL-divergence measures how much one probability distribution differs from another: a lower KL-divergence value indicates that the two distributions are more similar, and minimizing it pushes the generated distribution toward the real one.
$$L(y_{\text{pred}}, y_{\text{true}}) = y_{\text{true}} \cdot \log\left(\frac{y_{\text{true}}}{y_{\text{pred}}}\right) = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}})$$
Read more in research paper: Kullback-Leibler divergence (KL-divergence loss)
import torch
from torch.nn import KLDivLoss

# Create the input and target tensors
inputs = torch.randn(10, 10)
targets = torch.randn(10, 10)

# Convert the tensors to probability distributions.
# KLDivLoss expects the input (prediction) in log-space, so use log_softmax for it.
generate = torch.log_softmax(inputs, dim=1)
targets = torch.softmax(targets, dim=1)

# Create the KL-divergence loss function
loss_fn = KLDivLoss(reduction='batchmean')

# Calculate the loss
loss = loss_fn(generate, targets)
print(loss)
## Output>>> tensor(...)  # a non-negative value; it changes on every run because the inputs are random
Wasserstein Loss
The Wasserstein loss calculates the difference between the expected values of the discriminator's outputs for real and generated data. The discriminator is trained to distinguish between real and generated data. The Wasserstein loss is a non-saturating loss function, which means that it does not saturate as the generator becomes better at creating realistic data. This makes the Wasserstein loss more stable than other loss functions, and it is less likely to lead to the generator getting stuck in local minima.
The discriminator (critic) tries to maximize the following quantity; in code, we usually minimize its negative.👇
$$L_D = W(p_r, p_g) = \mathbb{E}_{x \sim p_r} [D(x)] - \mathbb{E}_{z \sim p_z} [D(G(z))]$$
The following formula is used by the generator.👇
$$L_G = -E_{z \sim p_z} [D(G(z))]$$
Read more in research paper: Wasserstein GAN Loss
import torch

# Wasserstein loss for the critic
def d_wasserstein_loss(p_real, p_fake):
    """
    Compute the Wasserstein loss for the discriminator (critic).
    Args:
        p_real (torch.Tensor): Predictions for real data.
        p_fake (torch.Tensor): Predictions for fake data.
    Returns:
        torch.Tensor: Wasserstein loss for the discriminator.
    """
    r_loss = torch.mean(p_real)
    f_loss = torch.mean(p_fake)
    return f_loss - r_loss

# Wasserstein loss for the generator
def g_wasserstein_loss(pred_fake):
    """
    Compute the Wasserstein loss for the generator.
    Args:
        pred_fake (torch.Tensor): Predictions for fake data.
    Returns:
        torch.Tensor: Wasserstein loss for the generator.
    """
    return -1 * torch.mean(pred_fake)
Gradient Penalty Loss
The gradient penalty is added to the Wasserstein loss to keep the critic's gradient norm close to 1 (the Lipschitz constraint). This helps prevent the discriminator from becoming too confident and makes it easier to train the generator.
One analogy to help understand what the gradient penalty loss does is to imagine a teacher training you to distinguish between apples and bananas. The teacher puts the apples and bananas in different baskets and teaches you to distinguish between them with high accuracy, giving you a reward each time you are correct. However, as you become better at telling the difference between apples and bananas, the teacher may start to give you smaller rewards, because it is redundant to reward you for completing a task that you have already mastered.
The gradient penalty loss works in a similar way. It encourages the discriminator to learn to distinguish between real and generated data, but it also penalizes the discriminator if it becomes too confident in its predictions. This helps to prevent the discriminator from overfitting to the training data and makes it easier for the generator to learn to create realistic data.
$$\lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ \left( \left\| \nabla_{\hat{x}} D(\hat{x}) \right\|_2 - 1 \right)^2 \right]$$
The formula above shows only the gradient penalty term. Now let's see the full critic objective, which combines the discriminator's Wasserstein loss (WLoss) with the gradient penalty (GP).
$$WGPLoss = \mathbb{E}_{\tilde{x} \sim p_g} [D(\tilde{x})] - \mathbb{E}_{x \sim p_r} [D(x)] + \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ \left( \left\| \nabla_{\hat{x}} D(\hat{x}) \right\|_2 - 1 \right)^2 \right]$$
# First implement gradient
def get_gradient(crit, real, fake, epsilon):
    '''
    Return the gradient of the critic's scores with respect to mixes of real and fake images.
    Parameters:
        crit: the critic model
        real: a batch of real images
        fake: a batch of fake images
        epsilon: a vector of the uniformly random proportions of real/fake per mixed image
    Returns:
        gradient: the gradient of the critic's scores, with respect to the mixed image
    '''
    mixed_images = real * epsilon + fake * (1 - epsilon)
    mixed_scores = crit(mixed_images)
    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

def gradient_penalty(gradient):
    '''
    Return the gradient penalty, given a gradient.
    Given a batch of image gradients, you calculate the magnitude of each image's gradient
    and penalize the mean quadratic distance of each magnitude to 1.
    Parameters:
        gradient: the gradient of the critic's scores, with respect to the mixed image
    Returns:
        penalty: the gradient penalty
    '''
    gradient = gradient.view(len(gradient), -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty
# Training loop using Wasserstein loss with gradient penalty
n_epochs = 5
cur_step = 0
c_lambda = 10
crit_repeats = 5
for epoch in range(n_epochs):
    for i, data in enumerate(dataloader):
        actual_image = data[0].to(device)
        b_size = actual_image.size(0)

        # Train the discriminator (critic)
        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            disc_opt.zero_grad()
            noise = torch.randn(b_size, 100, 1, 1, device=device)  # standard normal noise for the generator
            fake_image = gen(noise)
            fake_pred = disc(fake_image.detach())
            real_pred = disc(actual_image)
            epsilon = torch.rand(len(actual_image), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(disc, actual_image, fake_image.detach(), epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = d_wasserstein_loss(real_pred, fake_pred) + c_lambda * gp
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            crit_loss.backward(retain_graph=True)
            disc_opt.step()

        # Update the generator network
        gen_opt.zero_grad()
        fake_noise2 = torch.randn(b_size, 100, 1, 1, device=device)
        fake_image2 = gen(fake_noise2)
        fake_disc = disc(fake_image2)
        fake_gen_loss = g_wasserstein_loss(fake_disc)
        fake_gen_loss.backward()
        gen_opt.step()
Least Squares Loss
Least squares loss, also known as L2 loss or MSE loss, is typically used for regression tasks, but it can also be used as a loss function for GANs, where it is known as LSGAN. LSGAN loss has several advantages over binary cross-entropy loss, which is the traditional loss function used for GANs. LSGAN loss is less prone to vanishing gradients, more stable to train, and can generate higher quality images. It simply measures the mean squared error between the predicted values and the actual target values.
Binary cross-entropy loss is also a good loss function for GANs, but it has some limitations. It can only predict whether an image is real or fake, but it does not provide any information about how similar a fake sample is to a real sample. This can make it difficult to train the generator to produce high-quality images. Additionally, binary cross-entropy loss can become saturated when the generated samples are very different from the real samples, which can make it difficult to train the generator.
$$\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2$$
Read more in the research paper: Least Squares Generative Adversarial Networks
# PyTorch's built-in least squares loss is called MSELoss
m_loss = nn.MSELoss()
# Or, if you are using TensorFlow, pass the 'mse' string
# as the loss when compiling the model
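For context, here is a minimal sketch of how the least squares loss is typically plugged into a GAN training step (LSGAN-style). It assumes real_pred and fake_pred are the discriminator's outputs for real and generated batches; the 0/1 target values follow a common convention and are my own choice for the example.
import torch
import torch.nn as nn

mse = nn.MSELoss()

def lsgan_d_loss(real_pred, fake_pred):
    # Real samples are pushed toward 1, fake samples toward 0
    real_loss = mse(real_pred, torch.ones_like(real_pred))
    fake_loss = mse(fake_pred, torch.zeros_like(fake_pred))
    return real_loss + fake_loss

def lsgan_g_loss(fake_pred):
    # The generator wants the discriminator to output 1 for its fakes
    return mse(fake_pred, torch.ones_like(fake_pred))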
Pixel Wise Loss
The name already tells you how this loss is calculated: it measures the difference between the predicted image and the ground truth image pixel by pixel, rather than for the image as a whole. It indicates how similar the predicted image is to the actual image; a lower value is a good sign that the model is generating images that look like the real ones.
Some losses that can be used pixel-wise are L1Loss and MSELoss.
$$L1\ Loss\ or\ MAE = \frac{1}{n}\sum_{i=1}^{n}|y_i - \hat{y}_i|$$
# PyTorch's built-in least squares loss is called MSELoss
m_loss = nn.MSELoss()
# Or
l1loss = nn.L1Loss()
# Or, if you are using TensorFlow, pass the 'mse' string
# as the loss when compiling the model,
# or the 'mae' string to define an L1 loss in TensorFlow.
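As a quick illustration, here is a minimal sketch of applying a pixel-wise L1 loss to a batch of images; the tensor shapes and the random example data are assumptions made up for the demo.
import torch
import torch.nn as nn

l1loss = nn.L1Loss()

# Made-up "predicted" and "ground truth" batches: 4 RGB images of size 64x64
predicted = torch.rand(4, 3, 64, 64)
target = torch.rand(4, 3, 64, 64)

# The loss is averaged over every pixel of every image in the batch
loss = l1loss(predicted, target)
print(loss)  # a single scalar; lower means the images match more closely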
Perceptual Loss
Perceptual loss compares the generated image and the ground truth image in terms of their higher-level features, not just their raw pixels. This means it doesn't only calculate pixel-by-pixel differences but also captures information about the overall structure and appearance of the image.
The most popular way to implement this loss function is to first use a pre-trained model to extract features from both the real and generated images. Then, you calculate the difference between the two sets of higher-level features. However, it's also possible to use other types of loss functions to measure the differences in higher-level features.
One drawback to note is that this approach can be computationally expensive.
$$L_\text{perceptual} = \sum_{i=1}^{N} \left\| \Phi_i(G) - \Phi_i(T) \right\|_2^2$$
To learn more read this paper: Perceptual Losses for Real-Time Style Transfer and Super-Resolution
# Perceptual loss implementation in PyTorch
import torch
import torchvision.models as models
import torch.nn.functional as F
import torchvision.transforms as transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load a pretrained VGG16 model and keep only its convolutional feature extractor
vgg16 = models.vgg16(pretrained=True).to(device)
vgg16_features = vgg16.features

# Define a transform to preprocess images for the VGG model
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to VGG input size
])

def perceptual_loss(generate, target):
    # Preprocess generated and target images
    generate = preprocess(generate)
    target = preprocess(target)
    # Calculate VGG feature maps
    generate_features = vgg16_features(generate)
    target_features = vgg16_features(target)
    loss = 0.0  # Initialize the loss variable
    for generate_feat, target_feat in zip(generate_features, target_features):
        # Mean squared difference between corresponding feature maps
        loss += torch.mean((generate_feat - target_feat) ** 2)
    return loss
Feature Matching Loss
It works a little differently from perceptual loss, because this loss function calculates the difference between the generated image and the real image in terms of lower-level features, taken from the discriminator's intermediate layers, rather than higher-level ones. But like perceptual loss, it doesn't just measure the difference pixel by pixel; it also captures the overall structure of the image.
$$\mathcal{L}_{FM} = \frac{1}{N} \sum_{i=1}^{N} \left\| D_k^{(i)}(x) - D_k^{(i)}(G(x)) \right\|_1$$
import torch
import torch.nn.functional as F

def feature_matching_loss(real_pred, fake_pred):
    """Implements the feature matching loss in PyTorch.
    Args:
        real_pred: Tensor, output of the ground truth wave passed through the discriminator.
        fake_pred: Tensor, output of the generator prediction passed through the discriminator.
    Returns:
        Feature Matching Loss.
    """
    fm_loss = []
    for i in range(len(fake_pred)):
        for j in range(len(fake_pred[i]) - 1):
            fm_loss.append(F.l1_loss(real_pred[i][j], fake_pred[i][j]))
    return torch.mean(torch.stack(fm_loss))
Cycle Consistency Loss
Have you ever wondered how machines learn to convert images from one style to another? In the case of CycleGAN, it uses a clever technique called cycle consistency loss. This loss function works by comparing the original image to the revised image that has been converted back to the original style. But why use two generators instead of just one? The secret lies in the fact that the first generator converts the real image to a different domain, while the second generator transforms the converted domain image back into a real image. By calculating the difference between the second converted real image and the actual real image, we can evaluate how effective the model is at accurately converting images.
Think of it like this: imagine a machine that can turn a photograph of an ocean into a sketch of an ocean. The cycle consistency loss is like measuring how well the machine can turn the sketch back into a photograph of an ocean. It's all about checking how closely the final result matches the original. Just as we might compare a sketch of an ocean to a photograph of an ocean, the cycle consistency loss helps us ensure that our model is producing high-quality conversions.
$$L_\text{cyc}(G, F) = \mathbb{E}_{x \sim p_\text{data}(x)}\left[\|F(G(x)) - x\|_1\right] + \mathbb{E}_{y \sim p_\text{data}(y)}\left[\|G(F(y)) - y\|_1\right]$$
To learn more read research paper: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
import torch
from torch import nn

adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()
n_epochs = 20
dim_A = 3
dim_B = 3
batch_size = 1
lr = 0.0002
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Instantiate the generator and discriminator architectures (implementations not shown)
gen_AB = Generator(dim_A, dim_B).to(device)
gen_BA = Generator(dim_B, dim_A).to(device)
# Note: this may be new to you, because here two generators share one optimizer.
gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()), lr=lr, betas=(0.5, 0.999))
disc_A = Discriminator(dim_A).to(device)
disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5, 0.999))
disc_B = Discriminator(dim_B).to(device)
disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=lr, betas=(0.5, 0.999))
def get_cycle_consistency_loss(real_A, real_B, fake_A, fake_B, gen_BA, gen_AB, cycle_criterion, lambda_cycle=10):
    '''
    Return the cycle consistency loss of the generators given the inputs
    (and the generated images for testing purposes).
    Parameters:
        real_A: the real images from pile A
        real_B: the real images from pile B
        fake_A: the generated images of class A
        fake_B: the generated images of class B
        gen_BA: the generator for class B to A; takes images and returns the images
            transformed to class A
        gen_AB: the generator for class A to B; takes images and returns the images
            transformed to class B
        cycle_criterion: the cycle consistency loss function; it measures how well the
            generators can reverse their own transformations. It takes real images from
            pile A, passes them through the A->B generator and then through the B->A
            generator, and computes the difference between the original real images and
            the reconstructed images (and does the same symmetrically for pile B).
    '''
    cycle_A = gen_BA(fake_B)
    cycle_loss1 = cycle_criterion(cycle_A, real_A)
    cycle_B = gen_AB(fake_A)
    cycle_loss2 = cycle_criterion(cycle_B, real_B)
    # Total cycle loss
    total_cycle_loss = lambda_cycle * (cycle_loss1 + cycle_loss2)
    return total_cycle_loss

fake_B = gen_AB(real_A)
fake_A = gen_BA(real_B)
# Cycle-consistency loss -- get_cycle_consistency_loss(real_A, real_B, fake_A, fake_B, gen_BA, gen_AB, cycle_criterion)
gen_cycle_loss = get_cycle_consistency_loss(real_A, real_B, fake_A, fake_B, gen_BA, gen_AB, recon_criterion)
Identity Loss
The identity loss function is most commonly used in image-to-image translation models such as CycleGAN. It encourages the generator to behave like an identity map when it is given an image that is already from its target domain, which helps preserve the colour and core features of the input during translation.
$$L_\text{identity}(G, F) = \mathbb{E}_{y \sim p_\text{data}(y)} \left[\|G(y) - y\|_1\right] + \mathbb{E}_{x \sim p_\text{data}(x)} \left[\|F(x) - x\|_1\right]$$
To learn more, read the paper: Identity Loss For Generative-Adversarial-Network
import torch
from torch import nn

def get_identity_loss(real_A, real_B, gen_AB, gen_BA, identity_criterion):
    '''
    Return the identity loss of the generators given the inputs
    (and the generated images for testing purposes).
    Parameters:
        real_A: the real images from pile A
        real_B: the real images from pile B
        gen_AB: the generator for class A to B; takes images and returns the images
            transformed to class B
        gen_BA: the generator for class B to A; takes images and returns the images
            transformed to class A
        identity_criterion: the identity loss function; it takes the real images from A
            and those images put through the B->A generator and returns identity loss 1,
            then takes the real images from B and those images put through the A->B
            generator and returns identity loss 2, and finally combines identity loss 1
            and identity loss 2 (which you aim to minimize).
    '''
    identity_A = gen_BA(real_A)
    identity_loss_A = identity_criterion(identity_A, real_A)
    identity_B = gen_AB(real_B)
    identity_loss_B = identity_criterion(identity_B, real_B)
    gen_identity_loss = identity_loss_A + identity_loss_B
    return gen_identity_loss

identity_criterion = nn.L1Loss()
# Identity loss -- get_identity_loss(real_A, real_B, gen_AB, gen_BA, identity_criterion)
gen_identity_loss = get_identity_loss(real_A, real_B, gen_AB, gen_BA, identity_criterion)
Peak Signal-to-Noise Ratio (PSNR)
Peak signal-to-noise ratio (PSNR) is a loss function that can be used to train GANs, but it is not as common as other loss functions, such as adversarial loss and perceptual loss. PSNR is a measure of the quality of a reconstructed image compared to the original image. It is calculated as the ratio of the maximum possible power of a signal to the power of corrupting noise that affects the fidelity of its representation.
In the context of GANs, PSNR can be used to ensure that the generator produces high-quality images. However, it is important to note that PSNR is not a perfect measure of image quality. For example, PSNR can be high for an image that is blurred or has artefacts.
Here is an analogy to help you understand PSNR:
Imagine you have a perfect copy of a dog image. Now, you add some noise to the dog image. The more noise you add, the less similar the image will be to the original dog image. PSNR measures how much noise is in the dog image. A higher PSNR value indicates that the image is more similar to the original image and has less noise.
When training a GAN, PSNR can be used in conjunction with other loss functions to ensure that the generator produces high-quality images that are also realistic and perceptually similar to real images.
$$PSNR = 10 \cdot \log_{10}\left(\frac{MAX^2}{MSE}\right)$$
where $MAX$ is the maximum possible pixel value (1 for images scaled to $[0, 1]$).
To learn more read the paper: A Formal Evaluation of PSNR as Quality Measurement Parameter for Image Segmentation Algorithms
import torch

def psnr_loss(pred, target, max_val=1.0):
    mse = torch.mean((pred - target) ** 2)
    # PSNR = 10 * log10(MAX^2 / MSE); note that a higher PSNR means a better reconstruction
    return 10.0 * torch.log10(max_val ** 2 / mse)
Cosine Similarity Loss
Cosine similarity is a measure of the similarity between two non-zero vectors in an inner product space. It is calculated by taking the dot product of the two vectors and dividing it by the product of their magnitudes. This means that cosine similarity is independent of the magnitudes of the vectors, and only depends on their direction.
In general terms, cosine similarity measures how directionally similar two vectors are. A cosine similarity of 1 indicates that the two vectors are perfectly aligned, while a cosine similarity of -1 indicates that the two vectors are completely opposite. A cosine similarity of 0 indicates that the two vectors are orthogonal, or perpendicular, to each other.
Cosine similarity can be used in a variety of machine learning tasks, including:
Image retrieval: Cosine similarity can be used to measure the similarity between two images by comparing their feature vectors. This can be used to retrieve similar images from a database.
Natural language processing: Cosine similarity can be used to measure the similarity between two sentences or documents by comparing their word vectors. This can be used for tasks such as text classification and machine translation.
Recommendation systems: Cosine similarity can be used to measure the similarity between two users or items by comparing their feature vectors. This can be used to recommend items to users that they are likely to enjoy.
Cosine similarity is a simple and effective way to measure the similarity between two vectors. It is used in a variety of machine-learning tasks, including image retrieval, natural language processing, and recommendation systems.
$$\begin{align*}L_\text{cos-sim}(f(x), y) &= 1 - \frac{y \cdot f(x)}{\|y\| \|f(x)\|} \end{align*}$$
To learn more, read the research paper: A survey and taxonomy of loss functions in machine learning
import torch

# Note: this loss function only works on 1-D vectors
def cosine_similarity_loss(y_true, y_pred):
    # Compute the dot product of y_true and y_pred
    dot_product = torch.dot(y_true, y_pred)
    # Compute the L2 norm of y_true
    norm_y_true = torch.norm(y_true)
    # Compute the L2 norm of y_pred
    norm_y_pred = torch.norm(y_pred)
    # Compute the cosine similarity
    cosine_sim = dot_product / (norm_y_true * norm_y_pred)
    # Compute the cosine loss and ensure it's non-negative
    cosine_loss = torch.clamp(1 - cosine_sim, min=0)
    return cosine_loss
Structural Similarity Index Measure (SSIM) Loss
It measures the similarity between two images based on luminance, contrast, and structural differences. The first step is to break the two images down into three components: brightness (luminance), contrast, and structure. Next, the similarity between the corresponding components of the two images is measured. The final step is to calculate the weighted average of the similarity scores of the different components. The main advantage of the SSIM loss function is that it is particularly effective in image-to-image translation problems.
Please read this analogy to easily understand what the SSIM loss is doing.
Suppose you are trying to draw a picture of a man. You might draw the man's nose, eyes, and legs accurately, but if you don't get the overall shape of the man right, the picture will not look like the original man.
The SSIM loss function is similar to trying to get the man's overall shape correct. It considers the overall structure of the image, as well as how bright and dark it is. This means that the SSIM loss function is better at generating images that are realistic and pleasing to the eye.
$$\text{SSIM}(x, y) = \frac{(2 \mu_x \mu_y + C_1)(2 \sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}$$
To learn more read the paper: Image Quality Assessment: From Error Visibility to Structural Similarity
import torch
import torch.nn.functional as F

def ssim_loss(x, y):
    """Computes the Structural Similarity Index (SSIM) loss between two images.
    Args:
        x (torch.Tensor): input as predicted image.
        y (torch.Tensor): input as actual ground truth image.
    Returns:
        torch.Tensor: The SSIM loss measuring similarity between the two images.
    """
    # Calculate the local means of the two images.
    mu_x = F.avg_pool2d(x, 3, padding=1)
    mu_y = F.avg_pool2d(y, 3, padding=1)
    # Calculate the variance of the two images.
    sigma_x2 = F.avg_pool2d(x ** 2, 3, padding=1) - mu_x ** 2
    sigma_y2 = F.avg_pool2d(y ** 2, 3, padding=1) - mu_y ** 2
    # Calculate the covariance between the two images.
    sigma_xy = F.avg_pool2d(x * y, 3, padding=1) - mu_x * mu_y
    # Add small constants to avoid division by zero.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    # Calculate the SSIM map.
    ssim_l = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / ((mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x2 + sigma_y2 + C2))
    # Subtract the SSIM index from 1 to make the loss function compatible
    # with other loss functions that are minimized when the difference between
    # the predicted image and the ground truth image is zero.
    ssim_l = 1 - ssim_l
    # Take the mean of the loss values for each image in the batch
    # to get a single loss value that can be used to train the model.
    ssim_m = torch.mean(ssim_l)
    return ssim_m
Thanks for reading! I hope you found this article helpful. If you have any questions, please leave a comment below. And don't forget to sign up for my newsletter so you never miss a new post!