A Hands-On Guide to Implementing SimCLR for Self-Supervised Learning
In the rapidly evolving field of computer vision, self-supervised learning has emerged as a powerful technique for training robust feature extractors without labeled data. One of the most notable methods in this domain is SimCLR (a Simple Framework for Contrastive Learning of Visual Representations). In this tutorial, we will walk through a reimplementation of SimCLR, pretraining on the STL10 dataset, whose unlabeled split contains 100,000 images. By the end of this guide, you will have a solid understanding of how to apply SimCLR to other vision datasets and downstream tasks.
Understanding SimCLR: The Basics of Contrastive Learning
At the core of SimCLR is the concept of contrastive learning, which aims to learn representations by contrasting positive pairs against negative pairs. In this context, a positive pair consists of two augmented views of the same image, while negative pairs are formed from different images. The similarity between these representations is quantified using cosine similarity.
The loss function for a positive pair of examples \((i, j)\) is defined as:

\[
\ell_{i,j} = -\log \frac{\exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}, \boldsymbol{z}_{j}\right) / \tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}, \boldsymbol{z}_{k}\right) / \tau\right)}
\]

Here, \(\tau\) is a temperature parameter that controls the sharpness of the distribution, and \(\mathbb{1}_{[k \neq i]}\) is an indicator function that evaluates to 1 if \(k \neq i\) and 0 otherwise.
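To build intuition for the temperature, here is a small sketch (with made-up similarity scores, not the output of a real model) showing how a lower \(\tau\) sharpens the resulting distribution:

import torch

# Made-up cosine similarities between one anchor and three other samples
sims = torch.tensor([0.9, 0.1, -0.3])
for tau in (1.0, 0.5, 0.1):
    # Lower temperatures concentrate the probability mass on the most similar sample
    print(tau, torch.softmax(sims / tau, dim=0))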
L2 Normalization and Cosine Similarity
Before calculating the similarity, it is crucial to apply L2 normalization to the feature vectors. This ensures that all vectors lie on the surface of the unit hypersphere, making the cosine similarity calculation effective. The normalization can be performed using:
z_i = F.normalize(proj_1, p=2, dim=1)
z_j = F.normalize(proj_2, p=2, dim=1)
After normalization, we concatenate the two output views and compute the pairwise similarity matrix. Since the vectors are already L2-normalized, this is equivalent to multiplying the concatenated matrix by its transpose; the snippet below obtains the same result with F.cosine_similarity and broadcasting.
def calc_similarity_batch(self, a, b):
    # Stack the two sets of projections into a single (2N, d) matrix and
    # compute the full (2N, 2N) pairwise cosine similarity matrix via broadcasting
    representations = torch.cat([a, b], dim=0)
    return F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
Indexing the Similarity Matrix
To compute the loss, we need to index the similarity matrix appropriately to extract positive and negative pairs. The positive pairs are found on diagonals shifted by the batch size. For example:
sim_ij = torch.diag(similarity_matrix, batch_size)
sim_ji = torch.diag(similarity_matrix, -batch_size)
positives = torch.cat([sim_ij, sim_ji], dim=0)
The denominator terms are obtained by masking the similarity matrix so that each sample's similarity with itself (the main diagonal) is excluded; the positive pairs themselves remain in the denominator, exactly as the indicator \(\mathbb{1}_{[k \neq i]}\) in the loss prescribes. The short sketch below makes the indexing concrete.
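Here is a toy sketch (a stand-in 4×4 matrix for a batch size of 2, not real similarities) showing which entries the shifted diagonals and the mask pick out:

import torch

batch_size = 2
sim = torch.arange(16, dtype=torch.float32).reshape(4, 4)  # stand-in similarity matrix

print(torch.diag(sim, batch_size))    # tensor([2., 7.])   -> entries (0, 2) and (1, 3)
print(torch.diag(sim, -batch_size))   # tensor([ 8., 13.]) -> entries (2, 0) and (3, 1)

mask = (~torch.eye(4, dtype=bool)).float()
print(mask)  # ones everywhere except the main diagonal (self-similarities)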
Implementing the SimCLR Loss Function
The SimCLR loss function can be implemented as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    """NT-Xent loss (normalized temperature-scaled cross-entropy) used by SimCLR."""
    def __init__(self, batch_size, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        # Mask that zeroes out the main diagonal, i.e. each sample's similarity with itself
        self.mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float()

    def calc_similarity_batch(self, a, b):
        representations = torch.cat([a, b], dim=0)
        return F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)

    def forward(self, proj_1, proj_2):
        batch_size = proj_1.shape[0]
        # L2-normalize the projections so that cosine similarity is well defined
        z_i = F.normalize(proj_1, p=2, dim=1)
        z_j = F.normalize(proj_2, p=2, dim=1)
        similarity_matrix = self.calc_similarity_batch(z_i, z_j)
        # Positive pairs lie on the diagonals shifted by the batch size
        sim_ij = torch.diag(similarity_matrix, batch_size)
        sim_ji = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)
        nominator = torch.exp(positives / self.temperature)
        # Move the mask to the same device as the activations before masking the denominator
        denominator = self.mask.to(similarity_matrix.device) * torch.exp(similarity_matrix / self.temperature)
        all_losses = -torch.log(nominator / torch.sum(denominator, dim=1))
        loss = torch.sum(all_losses) / (2 * self.batch_size)
        return loss
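A quick shape check (random tensors standing in for real projection outputs) confirms that the loss reduces to a single scalar:

import torch

loss_fn = ContrastiveLoss(batch_size=8, temperature=0.5)
proj_1 = torch.randn(8, 128)  # fake projections for the first augmented view
proj_2 = torch.randn(8, 128)  # fake projections for the second augmented view
print(loss_fn(proj_1, proj_2))  # a single scalar tensor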
Data Augmentation: The Key to Self-Supervised Learning
Data augmentation plays a crucial role in self-supervised representation learning. A typical augmentation pipeline may include:
- Random cropping on a scale from 8% to 100% of the image (the default scale range of RandomResizedCrop).
- Resizing images to a standard dimension (e.g., 224×224).
- Horizontal flipping with a 50% probability.
- Heavy color jittering with an 80% probability.
- Gaussian blur with a 50% probability.
- Converting RGB images to grayscale with a 20% probability.
- Normalizing based on ImageNet statistics.
This augmentation pipeline generates two correlated views of the same image, which are then fed into the backbone model.
import torch
import torchvision.transforms as T

class Augment:
    """Applies the stochastic pipeline twice to produce two correlated views of each image."""
    def __init__(self, img_size, s=1):
        self.train_transform = torch.nn.Sequential(
            T.RandomResizedCrop(size=img_size),
            T.RandomHorizontalFlip(p=0.5),
            # Heavy color jitter with probability 0.8; s scales the jitter strength
            T.RandomApply([T.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)], p=0.8),
            T.RandomApply([T.GaussianBlur((3, 3), (0.1, 2.0))], p=0.5),
            T.RandomGrayscale(p=0.2),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        )

    def __call__(self, x):
        # Two independent passes through the random pipeline yield the two views
        return self.train_transform(x), self.train_transform(x)
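A quick sketch of how the class behaves (a random tensor in [0, 1] standing in for a batch of STL10 images, which are 96×96):

import torch

augment = Augment(img_size=96)
fake_batch = torch.rand(4, 3, 96, 96)  # four fake RGB images
view_1, view_2 = augment(fake_batch)
print(view_1.shape, view_2.shape)  # torch.Size([4, 3, 96, 96]) twice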
Modifying ResNet18 and Defining Parameter Groups
To implement SimCLR, we need to modify the ResNet18 architecture by removing the last fully connected layer and replacing it with an identity function. We then add a projection head (an MLP) that will be used exclusively during the self-supervised pretraining stage.
import torch.nn as nn
import torchvision.models as models

class AddProjection(nn.Module):
    """ResNet18 backbone with its classifier removed, followed by the SimCLR projection head."""
    def __init__(self, config, model=None, mlp_dim=512):
        super(AddProjection, self).__init__()
        self.backbone = models.resnet18(pretrained=False, num_classes=config.embedding_size)
        # Drop the final fully connected layer so the backbone outputs raw 512-d features
        self.backbone.fc = nn.Identity()
        self.projection = nn.Sequential(
            nn.Linear(in_features=mlp_dim, out_features=mlp_dim),
            nn.BatchNorm1d(mlp_dim),
            nn.ReLU(),
            nn.Linear(in_features=mlp_dim, out_features=config.embedding_size),
            nn.BatchNorm1d(config.embedding_size),
        )

    def forward(self, x):
        embedding = self.backbone(x)
        return self.projection(embedding)
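A shape check (using a throwaway config object; SimpleNamespace is just a convenient stand-in for whatever config class you use):

import torch
from types import SimpleNamespace

config = SimpleNamespace(embedding_size=128)  # illustrative config with only the needed field
model = AddProjection(config)
out = model(torch.randn(4, 3, 96, 96))
print(out.shape)  # torch.Size([4, 128])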
Defining Parameter Groups
To optimize the model effectively, we need to define parameter groups that exclude weight decay from batch normalization layers. This can be achieved with the following function:
def define_param_groups(model, weight_decay):
    # Batch-norm parameters go into a second group with weight decay disabled
    param_groups = [
        {
            'params': [p for name, p in model.named_parameters() if 'bn' not in name],
            'weight_decay': weight_decay,
        },
        {
            'params': [p for name, p in model.named_parameters() if 'bn' in name],
            'weight_decay': 0.0,
        },
    ]
    return param_groups
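A quick sanity check of the split (the 1e-6 weight decay and the throwaway config are illustrative values):

from types import SimpleNamespace

config = SimpleNamespace(embedding_size=128)
model = AddProjection(config)
param_groups = define_param_groups(model, weight_decay=1e-6)
print(len(param_groups[0]['params']), 'parameter tensors with weight decay')
print(len(param_groups[1]['params']), 'batch-norm parameter tensors without weight decay')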
Training Logic for SimCLR
The training logic for SimCLR involves taking two views of the data, forwarding them through the model to obtain embedding projections, and calculating the SimCLR loss. We can encapsulate this logic in a PyTorch Lightning module.
import pytorch_lightning as pl

class SimCLR_pl(pl.LightningModule):
    def __init__(self, config, model=None, feat_dim=512):
        super().__init__()
        self.config = config
        self.augment = Augment(config.img_size)
        self.model = AddProjection(config, model=model, mlp_dim=feat_dim)
        self.loss = ContrastiveLoss(config.batch_size, temperature=self.config.temperature)

    def training_step(self, batch, batch_idx):
        x, _ = batch                # labels are discarded during self-supervised pretraining
        x1, x2 = self.augment(x)    # two correlated views of the same batch
        z1 = self.model(x1)         # embedding projections for view 1
        z2 = self.model(x2)         # embedding projections for view 2
        loss = self.loss(z1, z2)
        self.log('Contrastive loss', loss)
        return loss
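One hook the excerpt above omits is configure_optimizers, which PyTorch Lightning requires before fit can run. A minimal sketch, assuming plain Adam, the define_param_groups helper from earlier, and hypothetical lr and weight_decay fields on the config object (a warmup plus cosine learning-rate schedule is a common addition on top):

import torch

# Added inside SimCLR_pl:
def configure_optimizers(self):
    # Plain Adam over the parameter groups; batch-norm parameters receive no weight decay
    param_groups = define_param_groups(self.model, self.config.weight_decay)
    return torch.optim.Adam(param_groups, lr=self.config.lr)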
Gradient Accumulation and Effective Batch Size
Using a large batch size is crucial for the success of SimCLR, since it determines how many negatives each example is contrasted against. If resources are limited, gradient accumulation can be employed to simulate a larger effective batch size by accumulating gradients over multiple steps before updating the model. For example, a per-step batch size of 128 with 4 accumulation steps yields an effective batch size of 512.
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# From epoch 0 onward, accumulate gradients over the configured number of batches
accumulator = GradientAccumulationScheduler(scheduling={0: train_config.gradient_accumulation_steps})
Main SimCLR Pretraining Script
The main script ties everything together, initializing the Trainer class of PyTorch Lightning and running the training process.
import torch
from pytorch_lightning import Trainer

available_gpus = torch.cuda.device_count()  # train on all visible GPUs (0 falls back to CPU)
trainer = Trainer(callbacks=[accumulator], gpus=available_gpus, max_epochs=train_config.epochs)
trainer.fit(model, data_loader)
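For completeness, here is a minimal end-to-end sketch of how the pieces could be wired together; the hyperparameter values and the SimpleNamespace config are illustrative assumptions rather than the exact settings of the original run:

import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from types import SimpleNamespace
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

train_config = SimpleNamespace(
    batch_size=128, epochs=50, img_size=96, embedding_size=128,
    temperature=0.5, lr=3e-4, weight_decay=1e-6,
    gradient_accumulation_steps=4,
)

# STL10's unlabeled split provides the 100,000 images used for pretraining
dataset = torchvision.datasets.STL10('./data', split='unlabeled', download=True,
                                     transform=T.ToTensor())
# drop_last=True keeps every batch at full size, matching the fixed mask in ContrastiveLoss
data_loader = DataLoader(dataset, batch_size=train_config.batch_size,
                         shuffle=True, num_workers=4, drop_last=True)

model = SimCLR_pl(train_config, feat_dim=512)
accumulator = GradientAccumulationScheduler(
    scheduling={0: train_config.gradient_accumulation_steps})

trainer = Trainer(callbacks=[accumulator],
                  gpus=torch.cuda.device_count(),
                  max_epochs=train_config.epochs)
trainer.fit(model, data_loader)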
Fine-Tuning the Model
After pretraining, the next step is fine-tuning the model on a specific task. This involves adding a linear layer on top of the pretrained backbone and training it on labeled data.
import torch.nn as nn
import pytorch_lightning as pl

class SimCLR_eval(pl.LightningModule):
    def __init__(self, lr, model=None):
        super().__init__()
        self.lr = lr
        # Linear classification head for the 10 STL10 classes on top of the pretrained backbone
        self.mlp = nn.Sequential(nn.Linear(512, 10))
        self.model = nn.Sequential(model, self.mlp)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        z = self.forward(x)
        loss = self.loss(z, y)
        return loss
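As before, Lightning also needs an optimizer hook, and the pretrained ResNet18 has to be pulled out of the SimCLR module. A minimal sketch; the checkpoint filename, the learning rate, and the choice of plain Adam are illustrative assumptions:

import torch

# Added inside SimCLR_eval:
def configure_optimizers(self):
    return torch.optim.Adam(self.model.parameters(), lr=self.lr)

# Recover the pretrained backbone; the checkpoint path below is a placeholder
pretrained = SimCLR_pl.load_from_checkpoint('simclr_stl10.ckpt', config=train_config)
backbone = pretrained.model.backbone  # ResNet18 without the projection head
evaluator = SimCLR_eval(lr=1e-3, model=backbone)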
Conclusion
In this tutorial, we explored the implementation of the SimCLR self-supervised learning method for pretraining robust feature extractors. We covered the key components, including the contrastive loss function, data augmentation strategies, and the training logic using PyTorch Lightning. Despite being a baseline method, SimCLR has shown promising results in various applications, and understanding its mechanics opens the door to experimenting with other self-supervised learning techniques.
Thank you for your interest in AI and self-supervised learning! Stay curious and keep exploring the fascinating world of deep learning.