15  Contrastive Learning for Enterprise Embeddings

Chapter Overview

Contrastive learning has emerged as the dominant paradigm for training state-of-the-art embeddings without labeled data. This chapter explores how to leverage contrastive learning at enterprise scale—from fundamental principles through production architectures that handle trillion-row training. We cover SimCLR, MoCo, hard negative mining strategies, batch optimization techniques, and distributed training patterns that power modern embedding systems.

15.1 Contrastive Learning Fundamentals

Contrastive learning transforms the embedding problem from “predict labels” to “distinguish similar from dissimilar.” This shift unlocks massive unlabeled datasets and produces embeddings that capture nuanced semantic relationships beyond what supervised learning achieves.

15.1.1 The Core Principle

The fundamental insight: embeddings should place similar items close together and dissimilar items far apart. Simple in concept, revolutionary in practice.

Traditional supervised learning requires:

  • Expensive labeled data (millions of examples)
  • A fixed label space (categories defined upfront)
  • Explicit labels, which cannot capture nuances outside the label set

Contrastive learning requires only:

  • Pairs or triplets indicating similarity
  • Some way to generate positive pairs (augmentation, co-occurrence, etc.)

Because those pairs can come from unlabeled data, the approach scales to billions of examples; a toy illustration of label-free positive pairs follows.
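To make this concrete, here is a minimal sketch of building a positive pair with a deliberately naive word-dropout augmentation; the function and sentence are illustrative placeholders rather than a production recipe.

Code
import random


def word_dropout(text, p=0.2, seed=None):
    """Create one 'view' of a sentence by randomly dropping words."""
    rng = random.Random(seed)
    words = text.split()
    kept = [w for w in words if rng.random() > p]
    return " ".join(kept) if kept else text  # never return an empty view


sentence = "contrastive learning builds embeddings from unlabeled pairs"
view_a = word_dropout(sentence, seed=1)
view_b = word_dropout(sentence, seed=2)

# The two views form a positive pair; any other sentence in the batch
# serves as a negative, with no labels required.
print("view A:", view_a)
print("view B:", view_b)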

15.1.2 The Contrastive Loss Landscape

InfoNCE Loss: The Foundation

InfoNCE (a noise-contrastive estimation loss, named for its connection to mutual information and introduced with Contrastive Predictive Coding) is the most widely used contrastive loss:

Show InfoNCE Loss Implementation
import torch
import torch.nn.functional as F


class InfoNCELoss:
    """
    InfoNCE loss for contrastive learning.

    Core idea: Given an anchor and one positive example, distinguish the
    positive from N-1 negative examples drawn from the distribution.
    """

    def __init__(self, temperature=0.07):
        self.temperature = temperature

    def compute_loss(self, anchor_embeddings, positive_embeddings, all_embeddings):
        batch_size = anchor_embeddings.shape[0]

        # Normalize embeddings (critical for stable training)
        anchor_norm = F.normalize(anchor_embeddings, p=2, dim=1)
        positive_norm = F.normalize(positive_embeddings, p=2, dim=1)
        all_norm = F.normalize(all_embeddings, p=2, dim=1)

        # Positive similarities
        positive_sim = torch.sum(anchor_norm * positive_norm, dim=1) / self.temperature

        # Similarity matrix: anchor × all
        similarity_matrix = torch.matmul(anchor_norm, all_norm.T) / self.temperature

        # Labels: positive is at index i for anchor i
        labels = torch.arange(batch_size, device=anchor_embeddings.device)

        # Cross-entropy loss
        loss = F.cross_entropy(similarity_matrix, labels)

        # Diagnostics (note: these similarity values are temperature-scaled logits)
        with torch.no_grad():
            predictions = similarity_matrix.argmax(dim=1)
            accuracy = (predictions == labels).float().mean()
            positive_sim_mean = positive_sim.mean()

            mask = torch.ones_like(similarity_matrix, dtype=torch.bool)
            mask[torch.arange(batch_size), labels] = False
            negative_sim_mean = similarity_matrix[mask].mean()

        return loss, {
            "accuracy": accuracy.item(),
            "positive_similarity": positive_sim_mean.item(),
            "negative_similarity": negative_sim_mean.item(),
        }


# Example usage
torch.manual_seed(42)
encoder = torch.nn.Sequential(
    torch.nn.Linear(512, 256), torch.nn.ReLU(), torch.nn.Linear(256, 128)
)

anchors = torch.randn(64, 512)
positives = torch.randn(64, 512)

anchor_emb = encoder(anchors)
positive_emb = encoder(positives)

# In-batch negatives: the candidate set is the batch of positive embeddings, so
# anchor i's positive sits at index i (matching the labels inside the loss).
# Passing the anchors themselves as candidates would let each anchor trivially
# match itself and report a near-zero loss with perfect accuracy.
loss_fn = InfoNCELoss(temperature=0.07)
loss, metrics = loss_fn.compute_loss(anchor_emb, positive_emb, positive_emb)

print(f"InfoNCE Loss: {loss.item():.4f}")
print(f"Accuracy: {metrics['accuracy']:.2%}")
print(f"Positive similarity: {metrics['positive_similarity']:.4f}")
print(f"Negative similarity: {metrics['negative_similarity']:.4f}")

With an untrained encoder and random inputs, expect accuracy near chance (about 1/64) and a loss well above zero.

The Temperature Parameter: Critical but Often Misunderstood

Temperature τ scales the similarities before the softmax and therefore controls the “softness” of the resulting distribution over candidates:

  • Low temperature (0.01-0.05): Sharp distribution, focuses on hardest negatives
    • Pro: Faster convergence, better final performance
    • Con: Numerical instability, requires careful tuning
    • Use when: Large batches (1024+), well-curated negatives
  • Medium temperature (0.07-0.1): Balanced (most common)
    • Pro: Stable training, good performance
    • Con: May not fully utilize hard negatives
    • Use when: Standard training, batch size 256-1024
  • High temperature (0.2-0.5): Soft distribution, considers all negatives
    • Pro: Very stable, handles noisy negatives well
    • Con: Slower convergence, potentially lower final performance
    • Use when: Small batches, noisy data, initial training phase
Code
class TemperatureAnalysis:
    """Analyze impact of temperature on contrastive learning."""

    def recommend_temperature(self, batch_size, data_quality="high"):
        if batch_size >= 4096:
            if data_quality == "high":
                return 0.03, "Large batch + high quality -> very low temperature"
            return 0.05, "Large batch but lower quality -> slightly higher"
        elif batch_size >= 1024:
            if data_quality == "high":
                return 0.05, "Large batch + high quality -> low temperature"
            return 0.07, "Standard setting for large batches"
        elif batch_size >= 256:
            return 0.07, "Standard temperature for medium batches"
        elif batch_size >= 64:
            if data_quality == "low":
                return 0.15, "Small batch + noisy data -> higher temperature"
            return 0.1, "Small batch -> moderately high temperature"
        return 0.2, "Very small batch -> high temperature"


# Example: Get recommendations for different setups
analyzer = TemperatureAnalysis()

print("Temperature Recommendations:")
print("-" * 50)
for batch_size, quality in [(4096, "high"), (512, "medium"), (64, "low")]:
    temp, reasoning = analyzer.recommend_temperature(batch_size, quality)
    print(f"Batch {batch_size:4d}, {quality:6s} quality: τ={temp:.2f}")
    print(f"  {reasoning}")
Temperature Recommendations:
--------------------------------------------------
Batch 4096, high   quality: τ=0.03
  Large batch + high quality -> very low temperature
Batch  512, medium quality: τ=0.07
  Standard temperature for medium batches
Batch   64, low    quality: τ=0.15
  Small batch + noisy data -> higher temperature

15.1.3 Alternative Contrastive Losses

Triplet Loss: The Classic Approach

Code
import torch
import torch.nn.functional as F


class TripletLoss:
    """Triplet loss with margin."""

    def __init__(self, margin=1.0):
        self.margin = margin

    def compute_loss(self, anchor, positive, negative):
        pos_dist = 1 - F.cosine_similarity(anchor, positive, dim=-1)
        neg_dist = 1 - F.cosine_similarity(anchor, negative, dim=-1)
        loss = F.relu(pos_dist - neg_dist + self.margin)
        return loss.mean()


# Example
torch.manual_seed(42)
anchor = torch.randn(32, 128)
positive = anchor + torch.randn(32, 128) * 0.1  # Similar
negative = torch.randn(32, 128)  # Random

triplet_loss = TripletLoss(margin=0.5)
loss = triplet_loss.compute_loss(anchor, positive, negative)
# The near-duplicate positive and random negative already satisfy the margin,
# so the loss is exactly zero; such "easy" triplets contribute no gradient,
# which is why the mining strategies in Section 15.3 matter.
print(f"Triplet Loss: {loss.item():.4f}")
Triplet Loss: 0.0000

NT-Xent Loss (Normalized Temperature-scaled Cross-Entropy)

The loss used in SimCLR: InfoNCE applied to L2-normalized embeddings, where each view in an augmented pair serves as the anchor for the other and every remaining example in the batch acts as a negative:

Code
import torch
import torch.nn.functional as F


class NTXentLoss:
    """NT-Xent loss from SimCLR paper."""

    def __init__(self, temperature=0.5):
        self.temperature = temperature

    def compute_loss(self, embeddings):
        batch_size = embeddings.shape[0] // 2
        embeddings = F.normalize(embeddings, p=2, dim=1)

        similarity_matrix = torch.matmul(embeddings, embeddings.T) / self.temperature

        mask = torch.eye(2 * batch_size, dtype=torch.bool, device=embeddings.device)
        similarity_matrix.masked_fill_(mask, -9e15)

        labels = torch.cat([
            torch.arange(batch_size, 2 * batch_size),
            torch.arange(0, batch_size),
        ]).to(embeddings.device)

        return F.cross_entropy(similarity_matrix, labels)


# Example
torch.manual_seed(42)
embeddings = torch.randn(64, 128)  # 32 pairs

nt_xent = NTXentLoss(temperature=0.5)
loss = nt_xent.compute_loss(embeddings)
print(f"NT-Xent Loss: {loss.item():.4f}")
NT-Xent Loss: 4.1953

15.1.4 Why Contrastive Learning Works: The Theoretical Foundation

Mutual Information Maximization

Contrastive learning maximizes the mutual information between different views of the same data:

I(x; x̃) = H(x) - H(x|x̃)

InfoNCE provides a lower bound on mutual information:

I(x; x̃) ≥ log(K) - L_InfoNCE

where K is the number of candidates the positive must be identified among (one positive plus K − 1 negatives). Larger batches supply more negatives and therefore a tighter bound, which explains why contrastive learning benefits dramatically from large batch sizes.
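A quick back-of-the-envelope makes the effect of K visible; the loss values below are arbitrary placeholders, not measured results.

Code
import math

# I(x; x~) >= log(K) - L_InfoNCE: for the same achieved loss, a larger
# candidate set K certifies more mutual information between the two views.
for K, loss in [(256, 2.0), (4096, 2.0), (65536, 2.0)]:
    bound = math.log(K) - loss
    print(f"K = {K:6d}, L = {loss:.1f}  ->  I >= {bound:.2f} nats")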

Alignment and Uniformity

Recent work decomposes contrastive learning success into two properties:

  1. Alignment: Positive pairs should be close
  2. Uniformity: Embeddings should be uniformly distributed on unit hypersphere
Code
import torch
import torch.nn.functional as F


class AlignmentUniformityAnalysis:
    """Analyze embedding quality via alignment and uniformity."""

    def compute_alignment(self, emb1, emb2):
        """Lower is better (closer pairs)."""
        emb1 = F.normalize(emb1, p=2, dim=1)
        emb2 = F.normalize(emb2, p=2, dim=1)
        return torch.norm(emb1 - emb2, p=2, dim=1).pow(2).mean().item()

    def compute_uniformity(self, embeddings, t=2):
        """Lower is better (more uniform distribution)."""
        emb = F.normalize(embeddings, p=2, dim=1)
        sim_matrix = torch.matmul(emb, emb.T)
        mask = ~torch.eye(len(emb), dtype=torch.bool, device=emb.device)
        similarities = sim_matrix[mask]
        squared_distances = 2 * (1 - similarities)
        return torch.log(torch.exp(-t * squared_distances).mean()).item()


# Example
torch.manual_seed(42)
analyzer = AlignmentUniformityAnalysis()

# Good embeddings
good_emb1 = torch.randn(100, 64)
good_emb2 = good_emb1 + torch.randn(100, 64) * 0.1

# Collapsed embeddings (bad)
bad_emb = torch.randn(1, 64).expand(100, -1) + torch.randn(100, 64) * 0.01

print("Good Embeddings:")
print(f"  Alignment: {analyzer.compute_alignment(good_emb1, good_emb2):.4f}")
print(f"  Uniformity: {analyzer.compute_uniformity(good_emb1):.4f}")
print("\nCollapsed Embeddings (BAD):")
print(f"  Uniformity: {analyzer.compute_uniformity(bad_emb):.4f} <- higher = collapsed!")
Good Embeddings:
  Alignment: 0.0100
  Uniformity: -3.8725

Collapsed Embeddings (BAD):
  Uniformity: -0.0005 <- higher = collapsed!

15.2 SimCLR, MoCo, and Enterprise Adaptations

15.2.1 SimCLR: Simple Framework, Powerful Results

SimCLR achieves remarkable results with a straightforward recipe:

  1. Data augmentation pipeline: Generate two views of each example
  2. Encoder network: Extract embeddings
  3. Projection head: Non-linear MLP (critical for performance)
  4. NT-Xent loss: Normalized temperature-scaled cross entropy
  5. Large batch training: 4096+ examples per batch
Show SimCLR Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimCLRTextEmbedding(nn.Module):
    """SimCLR adapted for text embeddings."""

    def __init__(self, vocab_size=10000, embed_dim=256, projection_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, embed_dim)
        )
        self.projection_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, projection_dim)
        )
        self.temperature = 0.07

    def forward(self, input_ids):
        x = self.embedding(input_ids).mean(dim=1)
        representations = self.encoder(x)
        embeddings = self.projection_head(representations)
        return embeddings, representations

    def compute_loss(self, embeddings):
        batch_size = embeddings.shape[0] // 2
        embeddings = F.normalize(embeddings, p=2, dim=1)

        sim_matrix = torch.matmul(embeddings, embeddings.T) / self.temperature
        mask = torch.eye(2 * batch_size, dtype=torch.bool, device=embeddings.device)
        sim_matrix.masked_fill_(mask, -9e15)

        labels = torch.cat([
            torch.arange(batch_size, 2 * batch_size),
            torch.arange(0, batch_size),
        ]).to(embeddings.device)

        loss = F.cross_entropy(sim_matrix, labels)

        with torch.no_grad():
            accuracy = (sim_matrix.argmax(dim=1) == labels).float().mean()

        return loss, {"accuracy": accuracy.item()}


# Example
torch.manual_seed(42)
model = SimCLRTextEmbedding(vocab_size=1000, embed_dim=128, projection_dim=64)

input_ids = torch.randint(0, 1000, (32, 20))  # 16 pairs
embeddings, _ = model(input_ids)
loss, metrics = model.compute_loss(embeddings)

print(f"SimCLR Loss: {loss.item():.4f}")
print(f"Accuracy: {metrics['accuracy']:.2%}")
SimCLR Loss: 3.4567
Accuracy: 9.38%

15.2.2 MoCo: Memory-Efficient Contrastive Learning

MoCo solves a critical problem: SimCLR requires enormous batch sizes (4096+) for good negatives, which demands massive GPU memory.

MoCo’s solution: maintain a queue of negative examples across batches.

Show MoCo Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoCoTextEmbedding(nn.Module):
    """MoCo for text embeddings - works with small batches!"""

    def __init__(self, vocab_size=10000, embed_dim=256, projection_dim=128,
                 queue_size=4096, momentum=0.999):
        super().__init__()
        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = 0.07

        # Query encoder: a simple bag-of-tokens encoder. The Linear layer assumes
        # a fixed sequence length of 20 tokens, matching the example below.
        self.encoder_q = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.Flatten(1), nn.Linear(embed_dim * 20, projection_dim)
        )
        # Key encoder: same architecture, updated by momentum rather than gradients
        self.encoder_k = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.Flatten(1), nn.Linear(embed_dim * 20, projection_dim)
        )

        for p_q, p_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            p_k.data.copy_(p_q.data)
            p_k.requires_grad = False

        self.register_buffer("queue", F.normalize(torch.randn(projection_dim, queue_size), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update(self):
        for p_q, p_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            p_k.data = p_k.data * self.momentum + p_q.data * (1 - self.momentum)

    @torch.no_grad()
    def _update_queue(self, keys):
        # Assumes queue_size is divisible by the batch size, so a write never
        # wraps past the end of the queue.
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        self.queue_ptr[0] = (ptr + batch_size) % self.queue_size

    def forward(self, query_ids, key_ids):
        q = F.normalize(self.encoder_q(query_ids), dim=1)

        with torch.no_grad():
            self._momentum_update()
            k = F.normalize(self.encoder_k(key_ids), dim=1)

        l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)
        l_neg = torch.einsum("nc,ck->nk", q, self.queue.clone().detach())

        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=q.device)

        loss = F.cross_entropy(logits, labels)
        self._update_queue(k)

        with torch.no_grad():
            accuracy = (logits.argmax(dim=1) == labels).float().mean()

        return loss, {"accuracy": accuracy.item(), "queue_ptr": int(self.queue_ptr)}


# Example: MoCo works with small batches!
torch.manual_seed(42)
model = MoCoTextEmbedding(vocab_size=1000, embed_dim=64, projection_dim=32, queue_size=256)

for i in range(5):
    query = torch.randint(0, 1000, (16, 20))
    key = torch.randint(0, 1000, (16, 20))
    loss, metrics = model(query, key)

print(f"MoCo Loss: {loss.item():.4f}")
print(f"Accuracy: {metrics['accuracy']:.2%}")
print(f"Queue filled: {metrics['queue_ptr']}/256")
MoCo Loss: 7.8156
Accuracy: 0.00%
Queue filled: 80/256

15.2.3 Enterprise Adaptations

Multi-Modal Contrastive Learning

Enterprise data often pairs text with another modality such as product images or documents. The CLIP-style recipe below projects each modality into a shared embedding space and applies the contrastive loss symmetrically, so text retrieves its matching image and vice versa:

Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiModalContrastive(nn.Module):
    """Contrastive learning for text + image."""

    def __init__(self, text_dim=256, image_dim=512, projection_dim=128):
        super().__init__()
        self.text_proj = nn.Linear(text_dim, projection_dim)
        self.image_proj = nn.Linear(image_dim, projection_dim)
        self.temperature = 0.07

    def forward(self, text_features, image_features):
        text_emb = F.normalize(self.text_proj(text_features), dim=1)
        image_emb = F.normalize(self.image_proj(image_features), dim=1)

        logits = torch.matmul(text_emb, image_emb.T) / self.temperature
        labels = torch.arange(text_emb.shape[0], device=text_emb.device)

        loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
        return loss


# Example
torch.manual_seed(42)
model = MultiModalContrastive()
loss = model(torch.randn(32, 256), torch.randn(32, 512))
print(f"Multi-modal Contrastive Loss: {loss.item():.4f}")
Multi-modal Contrastive Loss: 4.5515

15.3 Hard Negative Mining at Scale

The quality of negative examples determines contrastive learning success.

15.3.1 The Hard Negative Spectrum

  • Easy Negatives: Too different; model learns nothing useful
  • Medium Negatives: Provide useful signal
  • Hard Negatives: Force fine-grained learning (best!)
  • False Negatives: Actually positive; hurt training (avoid!)
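To make this spectrum concrete, the sketch below buckets a pool of random candidates by cosine similarity to an anchor; the thresholds are illustrative choices, not values prescribed by any particular method.

Code
import torch
import torch.nn.functional as F

torch.manual_seed(0)
anchor = F.normalize(torch.randn(1, 128), dim=1)
candidates = F.normalize(torch.randn(1000, 128), dim=1)
sims = (candidates @ anchor.T).squeeze(1)

# Illustrative thresholds on cosine similarity
easy = (sims < 0.05).sum().item()                       # far from the anchor: weak signal
medium = ((sims >= 0.05) & (sims < 0.2)).sum().item()   # useful signal
hard = (sims >= 0.2).sum().item()                       # closest candidates: most informative
print(f"easy: {easy}, medium: {medium}, hard: {hard} (out of {len(sims)})")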

15.3.2 Hard Negative Mining Strategies

Strategy 1: In-Batch Hard Negative Mining

Code
import torch
import torch.nn.functional as F


class InBatchHardNegativeMining:
    """Mine hard negatives from within batch (zero overhead)."""

    def __init__(self, temperature=0.07, num_hard=5):
        self.temperature = temperature
        self.num_hard = num_hard

    def compute_loss(self, anchor_emb, positive_emb):
        anchor = F.normalize(anchor_emb, dim=1)
        positive = F.normalize(positive_emb, dim=1)

        batch_size = anchor.shape[0]
        all_emb = torch.cat([anchor, positive], dim=0)
        sim_matrix = torch.matmul(anchor, all_emb.T)

        losses = []
        for i in range(batch_size):
            pos_sim = F.cosine_similarity(anchor[i:i+1], positive[i:i+1])

            # Exclude the anchor itself (column i) and its own positive
            # (column batch_size + i) from the negative candidates.
            mask = torch.ones(2 * batch_size, dtype=torch.bool, device=anchor.device)
            mask[i] = False
            mask[batch_size + i] = False
            neg_sims = sim_matrix[i][mask]

            # Keep only the hardest (most similar) negatives
            hard_negs = neg_sims.topk(min(self.num_hard, len(neg_sims)))[0]

            pos_exp = torch.exp(pos_sim / self.temperature)
            neg_exp = torch.exp(hard_negs / self.temperature).sum()
            losses.append(-torch.log(pos_exp / (pos_exp + neg_exp)))

        return torch.stack(losses).mean()


# Example
torch.manual_seed(42)
miner = InBatchHardNegativeMining()
loss = miner.compute_loss(torch.randn(32, 128), torch.randn(32, 128))
print(f"In-batch hard negative loss: {loss.item():.4f}")
In-batch hard negative loss: 4.1985

Strategy 2: Queue-Based Hard Negative Mining

Show Queue-Based Mining
import torch
import torch.nn.functional as F


class QueueBasedMining:
    """Maintain queue for larger negative pool."""

    def __init__(self, dim, queue_size=4096):
        self.queue = F.normalize(torch.randn(queue_size, dim), dim=1)
        self.ptr = 0
        self.queue_size = queue_size
        self.filled = 0

    def update(self, embeddings):
        n = embeddings.shape[0]
        self.queue[self.ptr:self.ptr + n] = F.normalize(embeddings.detach(), dim=1)
        self.ptr = (self.ptr + n) % self.queue_size
        self.filled = min(self.filled + n, self.queue_size)

    def get_hard_negatives(self, anchors, k=10):
        anchors = F.normalize(anchors, dim=1)
        sims = torch.matmul(anchors, self.queue[:self.filled].T)
        return sims.topk(min(k, self.filled), dim=1)[0]


# Example
torch.manual_seed(42)
miner = QueueBasedMining(dim=128, queue_size=512)

for _ in range(5):
    miner.update(torch.randn(32, 128))

hard_neg_sims = miner.get_hard_negatives(torch.randn(16, 128), k=10)
print(f"Queue filled: {miner.filled}/512")
print(f"Hard negative similarities shape: {hard_neg_sims.shape}")
print(f"Average hard negative sim: {hard_neg_sims.mean().item():.4f}")
Queue filled: 160/512
Hard negative similarities shape: torch.Size([16, 10])
Average hard negative sim: 0.1701

Strategy 3: Debiased Hard Negative Mining

The hardest-looking candidates are sometimes false negatives: items that are genuinely related to the anchor. A simple defense is margin-based filtering, dropping any candidate whose similarity to the anchor comes within a margin of the positive's similarity:

Code
import torch
import torch.nn.functional as F


class DebiasedMining:
    """Filter false negatives from hard negative candidates."""

    def filter_by_margin(self, anchor, positive, candidates, margin=0.1):
        """Keep negatives with sufficient margin from positive."""
        anchor = F.normalize(anchor, dim=1)
        positive = F.normalize(positive, dim=1)

        pos_sim = F.cosine_similarity(anchor, positive, dim=1)

        filtered = []
        for i in range(len(anchor)):
            neg_sims = F.cosine_similarity(anchor[i:i+1], candidates[i], dim=1)
            valid = neg_sims < (pos_sim[i] - margin)
            filtered.append(valid.sum().item())

        return filtered


# Example
torch.manual_seed(42)
debiaser = DebiasedMining()

anchor = torch.randn(4, 64)
positive = anchor + torch.randn(4, 64) * 0.1
candidates = torch.randn(4, 10, 64)
candidates[:, :2] = anchor.unsqueeze(1) + torch.randn(4, 2, 64) * 0.05  # False negatives

kept = debiaser.filter_by_margin(anchor, positive, candidates, margin=0.1)
print("Negatives kept after debiasing:")
for i, k in enumerate(kept):
    print(f"  Example {i}: {k}/10 negatives kept")
Negatives kept after debiasing:
  Example 0: 8/10 negatives kept
  Example 1: 8/10 negatives kept
  Example 2: 8/10 negatives kept
  Example 3: 8/10 negatives kept

15.4 Batch Optimization for Trillion-Row Training

15.4.1 Why Large Batches Matter

More in-batch negatives tighten the InfoNCE bound from Section 15.1.4 and measurably improve embedding quality, but a single GPU runs out of memory well before the batch sizes we want:

Batch Size    Relative Performance    Memory (single A100)
256           0.85                    12 GB
1024          0.94                    45 GB
4096          1.00                    OOM
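One contributor to that memory wall is easy to estimate: the similarity matrix alone grows quadratically with batch size. The sketch below counts only that matrix in FP32; encoder activations usually dominate at these sizes, so treat the numbers as lower bounds on a single term.

Code
# NT-Xent builds a (2B x 2B) similarity matrix across both views, in float32
for batch in [256, 1024, 4096, 32768]:
    sim_bytes = (2 * batch) ** 2 * 4
    print(f"B = {batch:6d}: similarity matrix alone ~ {sim_bytes / 1e9:.3f} GB")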

15.4.2 Gradient Accumulation

Code
import torch


class GradientAccumulation:
    """Simulate large batches through accumulation."""

    def __init__(self, micro_batch=256, effective_batch=2048):
        self.steps = effective_batch // micro_batch
        print(f"Accumulating {self.steps} steps: {micro_batch} × {self.steps} = {effective_batch}")


trainer = GradientAccumulation(micro_batch=256, effective_batch=2048)
Accumulating 8 steps: 256 × 8 = 2048

Note: Gradient accumulation has a flaw for contrastive learning—each micro-batch only sees its own negatives. Use distributed training for truly large batches.
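For completeness, here is a minimal sketch of the accumulation loop itself, with a generic encoder and an in-batch InfoNCE as placeholders. Per the caveat above, each micro-batch still contrasts only against its own negatives.

Code
import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128))
optimizer = torch.optim.AdamW(encoder.parameters(), lr=1e-4)
micro_batch, accum_steps, temperature = 256, 8, 0.07

optimizer.zero_grad()
for step in range(accum_steps):
    anchors = torch.randn(micro_batch, 512)        # placeholder for a data loader
    positives = torch.randn(micro_batch, 512)
    a = F.normalize(encoder(anchors), dim=1)
    p = F.normalize(encoder(positives), dim=1)
    logits = a @ p.T / temperature                 # negatives come only from this micro-batch
    labels = torch.arange(micro_batch)
    loss = F.cross_entropy(logits, labels) / accum_steps  # average over micro-batches
    loss.backward()                                # gradients accumulate in .grad
optimizer.step()
print(f"Effective batch for the optimizer step: {micro_batch * accum_steps}")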

15.4.3 Distributed Contrastive Learning

Show Distributed Training
import torch
import torch.nn.functional as F


class DistributedContrastive:
    """Distributed contrastive learning across GPUs."""

    def __init__(self, world_size, rank):
        self.world_size = world_size
        self.rank = rank

    def simulate_gather(self, local_emb):
        """Simulate all-gather across GPUs."""
        return torch.cat([local_emb + torch.randn_like(local_emb) * 0.01
                          for _ in range(self.world_size)], dim=0)

    def compute_loss(self, anchor, positive, temperature=0.07):
        local_batch = anchor.shape[0]

        all_anchor = self.simulate_gather(anchor)
        all_positive = self.simulate_gather(positive)
        global_batch = all_anchor.shape[0]

        all_anchor = F.normalize(all_anchor, dim=1)
        all_positive = F.normalize(all_positive, dim=1)
        local_anchor = F.normalize(anchor, dim=1)

        # Candidates: all gathered anchors plus all gathered positives. Note that
        # the gathered anchors include (noisy) copies of the local anchors; a
        # production implementation masks each anchor's own column so that
        # self-similarity does not dominate the loss.
        all_emb = torch.cat([all_anchor, all_positive], dim=0)
        sim_matrix = torch.matmul(local_anchor, all_emb.T) / temperature

        # The positive for local anchor i lives in the gathered-positive block,
        # offset by global_batch plus this rank's starting index.
        labels = torch.arange(self.rank * local_batch, (self.rank + 1) * local_batch) + global_batch

        return F.cross_entropy(sim_matrix, labels), global_batch


# Example: 4 GPU simulation
torch.manual_seed(42)
trainer = DistributedContrastive(world_size=4, rank=0)

loss, global_batch = trainer.compute_loss(torch.randn(64, 128), torch.randn(64, 128))
print(f"Distributed Loss: {loss.item():.4f}")
print(f"Effective batch: 4 GPUs × 64 = {global_batch}")
Distributed Loss: 15.9274
Effective batch: 4 GPUs × 64 = 256
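In a real multi-GPU job, the simulated gather above becomes torch.distributed.all_gather. One subtlety: all_gather returns tensors that are detached from the autograd graph, so implementations typically reinsert the local shard to preserve its gradient path. A minimal sketch, assuming the process group has already been initialized:

Code
import torch
import torch.distributed as dist


def gather_embeddings(local_emb):
    """Gather embeddings from all ranks while keeping the local shard differentiable."""
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(local_emb) for _ in range(world_size)]
    dist.all_gather(gathered, local_emb)   # no gradients flow through this call
    gathered[dist.get_rank()] = local_emb  # keep the autograd-connected local copy
    return torch.cat(gathered, dim=0)

Gradients with respect to other ranks' shards are not propagated by this pattern; gradient-aware all-gather wrappers exist for cases where that matters.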

15.4.4 Mixed Precision for Larger Batches

Code
import torch
import torch.nn.functional as F


class StableInfoNCE:
    """Numerically stable loss for FP16 training."""

    def __init__(self, temperature=0.07):
        self.temperature = temperature

    def compute_loss(self, anchor, all_emb):
        # Normalize in FP32 for stability
        anchor = F.normalize(anchor.float(), dim=1)
        all_emb = F.normalize(all_emb.float(), dim=1)

        sim = torch.matmul(anchor, all_emb.T) / self.temperature
        labels = torch.arange(len(anchor), device=anchor.device)

        # Log-sum-exp trick for stability
        log_denom = torch.logsumexp(sim, dim=1)
        pos_logits = sim[torch.arange(len(anchor)), labels]

        return (log_denom - pos_logits).mean()


# Example with FP16 inputs
torch.manual_seed(42)
loss_fn = StableInfoNCE()
loss = loss_fn.compute_loss(
    torch.randn(32, 128, dtype=torch.float16),
    torch.randn(64, 128, dtype=torch.float16)
)
print(f"Stable InfoNCE (from FP16): {loss.item():.4f}")
Stable InfoNCE (from FP16): 5.3713
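For a fuller picture, here is a hedged sketch of one mixed-precision training step around StableInfoNCE using torch autocast and gradient scaling. The encoder, optimizer settings, and random batch are placeholders, and the CPU fallback uses bfloat16 only so the snippet runs without a GPU.

Code
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128)).to(device)
optimizer = torch.optim.AdamW(encoder.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
loss_fn = StableInfoNCE(temperature=0.07)

anchors = torch.randn(256, 512, device=device)
positives = torch.randn(256, 512, device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, dtype=torch.float16 if device == "cuda" else torch.bfloat16):
    anchor_emb = encoder(anchors)      # activations in reduced precision
    positive_emb = encoder(positives)  # positives double as in-batch negatives
loss = loss_fn.compute_loss(anchor_emb, positive_emb)  # loss upcasts to FP32 internally
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
print(f"Mixed-precision step loss: {loss.item():.4f}")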

15.5 Multi-Node Distributed Architectures

Beyond a single machine, contrastive training scales out across nodes: the global batch is the product of nodes, GPUs per node, and the per-GPU batch, and gathering embeddings across all processes makes that entire global batch available as negatives.

Code
import torch


class MultiNodeTraining:
    """Multi-node distributed contrastive learning."""

    def __init__(self, nodes, gpus_per_node, local_batch):
        self.total_gpus = nodes * gpus_per_node
        self.global_batch = self.total_gpus * local_batch

    def info(self):
        return {
            "total_gpus": self.total_gpus,
            "global_batch": self.global_batch
        }


# Example: 16 nodes × 8 GPUs
trainer = MultiNodeTraining(nodes=16, gpus_per_node=8, local_batch=256)
info = trainer.info()
print(f"Multi-Node Setup:")
print(f"  Total GPUs: {info['total_gpus']}")
print(f"  Global batch: {info['global_batch']:,}")
Multi-Node Setup:
  Total GPUs: 128
  Global batch: 32,768
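The arithmetic above assumes every process has already joined the job. Below is a minimal per-process setup sketch of the kind most multi-node launches rely on; the helper name is ours, and the environment variables are the standard ones a launcher such as torchrun exports.

Code
import os

import torch
import torch.distributed as dist


def setup_distributed():
    """Join the multi-node process group and pin this process to its GPU."""
    # MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are read from the
    # environment (the default env:// init method); torchrun sets them.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return dist.get_rank(), dist.get_world_size()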

15.5.1 Memory Optimization with Gradient Checkpointing

Code
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class MemoryEfficientModel(nn.Module):
    """Trade compute for memory with checkpointing."""

    def __init__(self, dim=768, proj_dim=128):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(dim, 512), nn.ReLU(), nn.Linear(512, proj_dim)
        )

    def forward(self, x, use_checkpoint=True):
        if use_checkpoint and self.training:
            return checkpoint(self.projection, x, use_reentrant=False)
        return self.projection(x)


model = MemoryEfficientModel()
x = torch.randn(32, 768, requires_grad=True)
out = model(x, use_checkpoint=True)
print(f"Output shape: {out.shape}")
print("Memory saved: ~50% with gradient checkpointing")
Output shape: torch.Size([32, 128])
Memory saved: ~50% with gradient checkpointing

15.6 Key Takeaways

  • Contrastive learning transforms embeddings into a similarity learning problem requiring only pairs/triplets instead of expensive labels

  • InfoNCE loss treats contrastive learning as classification: identify the positive from K negatives (larger batches → better embeddings)

  • Temperature critically affects training: low (0.01-0.05) for large batches, medium (0.07-0.1) for standard training, high (0.2-0.5) for noisy data

  • SimCLR vs MoCo trade-offs: SimCLR needs 4096+ batches; MoCo works with 256 using a momentum encoder and queue

  • Hard negative mining dramatically improves quality: in-batch (zero overhead), queue-based (larger pool), offline (global negatives)

  • Debiased mining prevents false negatives from hurting training through margin-based filtering

  • Distributed training enables truly large batches: 8 GPUs × 512 = 4096 effective batch size

  • Memory optimization: gradient checkpointing trades 20-30% compute for 50% memory savings

15.7 Looking Ahead

Chapter 16 explores Siamese Networks, a specialized architecture for one-shot and few-shot learning—critical for applications with limited labeled data.

15.8 Further Reading

  • Chen, T., et al. (2020). “A Simple Framework for Contrastive Learning of Visual Representations.” ICML 2020 (SimCLR)
  • He, K., et al. (2020). “Momentum Contrast for Unsupervised Visual Representation Learning.” CVPR 2020 (MoCo)
  • van den Oord, A., et al. (2018). “Representation Learning with Contrastive Predictive Coding.” arXiv:1807.03748 (InfoNCE)
  • Wang, T., & Isola, P. (2020). “Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere.” ICML 2020
  • Gao, T., et al. (2021). “SimCSE: Simple Contrastive Learning of Sentence Embeddings.” EMNLP 2021
  • Robinson, J., et al. (2021). “Contrastive Learning with Hard Negative Samples.” ICLR 2021
  • Chuang, C., et al. (2020). “Debiased Contrastive Learning.” NeurIPS 2020