17  Self-Supervised Learning Pipelines

Note: Chapter Overview

While contrastive learning (Chapter 5) and Siamese networks (Chapter 6) require labeled pairs or triplets, self-supervised learning unlocks the ability to learn from unlabeled data at unprecedented scale. This chapter explores self-supervised techniques that leverage the inherent structure of data to create powerful embeddings without manual annotation. We cover masked language modeling for domain-specific text, vision transformers for industrial imagery, time-series forecasting approaches, and multi-modal self-supervision strategies. These techniques enable enterprises to train embeddings on trillions of unlabeled documents, images, and sensor readings—data that already exists but was previously unusable for training.

17.1 Self-Supervised Learning for Unlabeled Enterprise Data

The fundamental challenge facing enterprise AI: you have petabytes of data but almost no labels. Traditional supervised learning requires expensive manual annotation. Self-supervised learning solves this by turning the data itself into both input and supervision.

17.1.1 The Self-Supervised Paradigm

Self-supervised learning creates “pretext tasks” where the model must predict part of the input from other parts. The key insight: by learning to solve these pretext tasks, the model develops representations that capture the underlying structure of the data.

Common pretext tasks:

  • Masked prediction: Predict hidden parts (BERT, MAE)
  • Next token prediction: Predict future content (GPT, autoregressive models)
  • Contrastive prediction: Distinguish augmented views (SimCLR, MoCo)
  • Reconstruction: Rebuild input from transformed version (autoencoders)
Self-Supervised Embedding Framework
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfSupervisedEmbeddingFramework:
    """Framework for self-supervised learning on enterprise data.

    Supports masked prediction, contrastive learning, and reconstruction tasks.
    """

    def __init__(self, encoder_model, pretext_task="masked", embedding_dim=768, mask_probability=0.15):
        self.encoder = encoder_model
        self.pretext_task = pretext_task
        self.embedding_dim = embedding_dim
        self.mask_probability = mask_probability

        if pretext_task == "masked":
            self.prediction_head = nn.Linear(embedding_dim, embedding_dim)
        elif pretext_task == "contrastive":
            self.projection_head = nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim), nn.ReLU(), nn.Linear(embedding_dim, 128)
            )
        elif pretext_task == "reconstruction":
            self.decoder = self._build_decoder(embedding_dim)

    def _build_decoder(self, embedding_dim):
        return nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 2), nn.ReLU(),
            nn.Linear(embedding_dim * 2, embedding_dim * 4), nn.ReLU(),
            nn.Linear(embedding_dim * 4, embedding_dim)
        )

    def create_pretext_task(self, batch):
        """Create pretext task from unlabeled batch."""
        if self.pretext_task == "masked":
            return self._create_masked_task(batch)
        elif self.pretext_task == "contrastive":
            return self._create_contrastive_task(batch)
        elif self.pretext_task == "reconstruction":
            return self._create_reconstruction_task(batch)

    def _create_masked_task(self, batch):
        batch_size, seq_len, features = batch.shape
        mask = torch.rand(batch_size, seq_len) < self.mask_probability
        inputs = batch.clone()
        inputs[mask] = 0
        return inputs, batch.clone(), mask

    def _create_contrastive_task(self, batch):
        view1 = self._augment(batch)
        view2 = self._augment(batch)
        return (view1, view2), None, None

    def _create_reconstruction_task(self, batch):
        noise = torch.randn_like(batch) * 0.1
        return batch + noise, batch, None

    def _augment(self, batch):
        noise = torch.randn_like(batch) * 0.05
        return batch + noise

    def forward(self, inputs):
        return self.encoder(inputs)

    def compute_loss(self, inputs, targets, mask=None):
        """Compute loss for pretext task."""
        if self.pretext_task == "masked":
            embeddings = self.encoder(inputs)
            predictions = self.prediction_head(embeddings)
            loss = F.mse_loss(predictions[mask], targets[mask])
            with torch.no_grad():
                accuracy = ((predictions[mask] - targets[mask]).abs() < 0.1).float().mean()
            return loss, {"loss": loss.item(), "accuracy": accuracy.item()}
        # Similar for other tasks...


# Usage example
encoder = nn.Sequential(nn.Linear(768, 768), nn.ReLU())
framework = SelfSupervisedEmbeddingFramework(encoder, pretext_task="masked")
batch = torch.randn(32, 512, 768)
inputs, targets, mask = framework.create_pretext_task(batch)
loss, metrics = framework.compute_loss(inputs, targets, mask)
print(f"Loss: {metrics['loss']:.4f}, Accuracy: {metrics['accuracy']:.4f}")
Loss: 1.0015, Accuracy: 0.0800
Tip: Choosing the Right Pretext Task

Masked prediction: Best for structured data with natural ordering (text, sequences, time-series). Captures bidirectional context.

Contrastive learning: Best when you can define meaningful augmentations. Works well for images, audio, multimodal data.

Reconstruction: Best for high-dimensional data where reconstruction is meaningful. Good for images, sensor data.

Rule of thumb: If your data has natural ordering, use masked prediction. If augmentations preserve semantics, use contrastive. If neither, try reconstruction.
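The rule of thumb can be captured in a small helper. This is an illustrative sketch only; the function name and flags are hypothetical, not part of any library:

def choose_pretext_task(has_natural_ordering: bool, augmentations_preserve_semantics: bool) -> str:
    """Illustrative helper encoding the rule of thumb above (hypothetical, not a library API)."""
    if has_natural_ordering:
        return "masked"         # text, sequences, time-series
    if augmentations_preserve_semantics:
        return "contrastive"    # images, audio, multimodal
    return "reconstruction"     # fallback for high-dimensional data


# e.g. sensor logs have a natural temporal ordering -> masked prediction
print(choose_pretext_task(has_natural_ordering=True, augmentations_preserve_semantics=False))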

17.1.2 Enterprise Self-Supervised Pipeline

Production self-supervised learning requires careful data management and training infrastructure:

Enterprise Self-Supervised Pipeline
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler


class EnterpriseSelfSupervisedPipeline:
    """Production self-supervised learning pipeline with distributed training,
    checkpointing, and monitoring.
    """

    def __init__(self, model, data_source, batch_size=256, num_workers=8,
                 checkpoint_dir="./checkpoints", log_dir="./logs"):
        self.model = model
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.checkpoint_dir = checkpoint_dir
        self.log_dir = log_dir
        self.world_size = torch.cuda.device_count()
        self.is_distributed = self.world_size > 1

    def setup_distributed(self):
        """Initialize distributed training (assumes a single-node, multi-GPU launch)."""
        if self.is_distributed:
            dist.init_process_group(backend="nccl")
            # On a single node the global rank doubles as the device index;
            # multi-node jobs should read LOCAL_RANK from the environment instead.
            local_rank = dist.get_rank()
            torch.cuda.set_device(local_rank)
            self.model = DistributedDataParallel(self.model, device_ids=[local_rank])

    def train(self, num_epochs=100, learning_rate=1e-4):
        """Train self-supervised model."""
        self.setup_distributed()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

        for epoch in range(num_epochs):
            self.model.train()
            epoch_loss = 0
            num_batches = 0
            # Training loop implementation...
            scheduler.step()
            if epoch % 10 == 0:
                self.save_checkpoint(epoch, epoch_loss / max(num_batches, 1))

    def save_checkpoint(self, epoch, loss):
        """Save model checkpoint (unwrap DDP so the state dict keys stay portable)."""
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        checkpoint = {"epoch": epoch, "model_state_dict": model_to_save.state_dict(), "loss": loss}
        path = f"{self.checkpoint_dir}/checkpoint_epoch_{epoch}.pt"
        torch.save(checkpoint, path)
        print(f"Checkpoint saved: {path}")


# Usage example
model = torch.nn.Sequential(torch.nn.Linear(512, 768), torch.nn.ReLU())
pipeline = EnterpriseSelfSupervisedPipeline(model, data_source="s3://bucket/data")
# pipeline.train(num_epochs=10, learning_rate=1e-4)
print("Pipeline configured for distributed SSL training")
Pipeline configured for distributed SSL training
Warning: Production Considerations

Data Quality: Self-supervised learning amplifies data quality issues. Bad data → bad embeddings. Filter corrupted samples before training.

Compute Budget: Training on billions of samples requires significant compute. For 100M parameters × 1B tokens, expect 100-1000 GPU-hours.

Checkpoint Frequency: Save checkpoints every 1-2 hours of wall-clock training time (not every N epochs). Spot instance interruptions are common; see the sketch after this callout.

Monitoring: Track loss trends, gradient norms, and embedding quality metrics. Diverging loss indicates instability.
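A minimal sketch of how the elided training loop might implement these recommendations: a DistributedSampler-backed loader, gradient-norm tracking, and wall-clock checkpointing. The `train_one_epoch` helper and the placeholder pretext loss are illustrative assumptions rather than part of the pipeline class above; it reuses the pipeline's `save_checkpoint` method.

import time

import torch
from torch.utils.data import DataLoader, DistributedSampler


def train_one_epoch(pipeline, dataset, optimizer, epoch, checkpoint_interval_s=3600):
    """Illustrative inner loop: gradient-norm monitoring and time-based checkpointing."""
    sampler = DistributedSampler(dataset) if pipeline.is_distributed else None
    loader = DataLoader(dataset, batch_size=pipeline.batch_size, sampler=sampler,
                        shuffle=(sampler is None), num_workers=pipeline.num_workers)
    if sampler is not None:
        sampler.set_epoch(epoch)  # reshuffle shards each epoch

    last_checkpoint = time.time()
    for step, batch in enumerate(loader):
        loss = pipeline.model(batch).mean()  # placeholder pretext-task loss
        optimizer.zero_grad()
        loss.backward()
        # Clip and record the gradient norm; a sustained spike signals instability.
        grad_norm = float(torch.nn.utils.clip_grad_norm_(pipeline.model.parameters(), max_norm=1.0))
        optimizer.step()

        if step % 100 == 0:
            print(f"epoch {epoch} step {step}: loss={loss.item():.4f} grad_norm={grad_norm:.2f}")

        # Checkpoint on wall-clock time, not epochs, to survive spot interruptions.
        if time.time() - last_checkpoint > checkpoint_interval_s:
            pipeline.save_checkpoint(epoch, loss.item())
            last_checkpoint = time.time()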

17.2 Masked Language Modeling for Domain-Specific Text

Masked Language Modeling (MLM), popularized by BERT, remains a foundational technique for learning text embeddings. For enterprises, the key is adapting MLM to domain-specific vocabulary and writing styles.

17.2.1 MLM Fundamentals

The MLM objective: predict randomly masked tokens from surrounding context. This forces the model to learn bidirectional representations that capture semantic and syntactic patterns.

Domain-Specific MLM
import torch
from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast


class DomainSpecificMLM:
    """Masked Language Modeling for domain-specific text (legal, medical, financial, etc.)."""

    def __init__(self, domain="general", vocab_size=30000, hidden_size=768, num_layers=12, num_heads=12):
        self.domain = domain
        self.config = BertConfig(
            vocab_size=vocab_size, hidden_size=hidden_size,
            num_hidden_layers=num_layers, num_attention_heads=num_heads,
            intermediate_size=hidden_size * 4, max_position_embeddings=512
        )
        self.model = BertForMaskedLM(self.config)
        self.tokenizer = None

    def train_tokenizer(self, text_corpus, save_path="./tokenizer"):
        """Train domain-specific tokenizer - critical for specialized domains."""
        import os

        from tokenizers import Tokenizer
        from tokenizers.models import BPE
        from tokenizers.trainers import BpeTrainer

        tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        trainer = BpeTrainer(
            vocab_size=self.config.vocab_size,
            special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
        )
        tokenizer.train_from_iterator(text_corpus, trainer=trainer)
        os.makedirs(save_path, exist_ok=True)
        tokenizer.save(f"{save_path}/tokenizer.json")
        # A BPE tokenizer saved as tokenizer.json is loaded via the fast tokenizer class;
        # BertTokenizer expects a WordPiece vocab.txt and would fail here.
        self.tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=f"{save_path}/tokenizer.json",
            unk_token="[UNK]", pad_token="[PAD]", cls_token="[CLS]",
            sep_token="[SEP]", mask_token="[MASK]",
        )
        print(f"Tokenizer trained and saved to {save_path}")

    def get_embeddings(self, texts, layer=-1):
        """Extract embeddings from trained model."""
        self.model.eval()
        inputs = self.tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.bert(**inputs, output_hidden_states=True)
        embeddings = outputs.hidden_states[layer].mean(dim=1)
        return embeddings.numpy()


# Usage example
legal_mlm = DomainSpecificMLM(domain="legal", vocab_size=32000)
legal_corpus = ["The plaintiff filed a motion...", "Under tort law, negligence requires..."]
# legal_mlm.train_tokenizer(legal_corpus)
print(f"Domain: {legal_mlm.domain}, Vocab size: {legal_mlm.config.vocab_size}")
Domain: legal, Vocab size: 32000
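The listing above builds the model and tokenizer but leaves out the masking objective itself. A minimal training sketch, assuming `legal_mlm.train_tokenizer` has been run and that `tokenized_dataset` is a pre-tokenized dataset with an `input_ids` column (a hypothetical placeholder); the Hugging Face `DataCollatorForLanguageModeling` applies BERT's standard 80/10/10 mask-replace-keep rule dynamically per batch:

from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Dynamic masking: 15% of tokens are selected; 80% become [MASK], 10% a random token,
# 10% are left unchanged.
collator = DataCollatorForLanguageModeling(
    tokenizer=legal_mlm.tokenizer, mlm=True, mlm_probability=0.15
)

training_args = TrainingArguments(
    output_dir="./mlm_checkpoints",
    per_device_train_batch_size=32,
    learning_rate=5e-5,          # drop to 2e-5 when adapting a pre-trained checkpoint
    warmup_ratio=0.1,
    num_train_epochs=3,
    fp16=True,                   # mixed precision roughly halves training time
)

trainer = Trainer(
    model=legal_mlm.model,
    args=training_args,
    train_dataset=tokenized_dataset,  # hypothetical datasets.Dataset with input_ids
    data_collator=collator,
)
# trainer.train()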

17.2.2 Advanced MLM Techniques

For production deployments, basic MLM can be enhanced with several techniques:

Advanced MLM Techniques
import numpy as np
import torch
from transformers import BertForMaskedLM, BertTokenizer


class AdvancedMLM:
    """Advanced MLM with whole word masking, span masking, and entity-aware masking."""

    def __init__(self, base_model, tokenizer):
        self.model = base_model
        self.tokenizer = tokenizer

    def whole_word_masking(self, input_ids, mlm_probability=0.15):
        """Mask entire words instead of subword tokens for better semantics."""
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids.tolist())
        words = []
        current_word = []

        for idx, token in enumerate(tokens):
            if token in self.tokenizer.all_special_tokens:
                continue  # never mask [CLS], [SEP], or padding
            if token.startswith("##"):
                current_word.append(idx)
            else:
                if current_word:
                    words.append(current_word)
                current_word = [idx]
        if current_word:
            words.append(current_word)

        num_words_to_mask = max(1, int(len(words) * mlm_probability))
        words_to_mask = np.random.choice(len(words), size=num_words_to_mask, replace=False)

        mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for word_idx in words_to_mask:
            for token_idx in words[word_idx]:
                mask[token_idx] = True
        return mask

    def span_masking(self, input_ids, span_length=3, mlm_probability=0.15):
        """Mask contiguous spans of tokens for longer-range dependencies (SpanBERT)."""
        seq_len = len(input_ids)
        num_masks = int(seq_len * mlm_probability / span_length)
        mask = torch.zeros_like(input_ids, dtype=torch.bool)

        for _ in range(num_masks):
            start = np.random.randint(0, max(1, seq_len - span_length))
            mask[start:start + span_length] = True
        return mask


# Usage example
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
advanced_mlm = AdvancedMLM(model, tokenizer)
text = "The quick brown fox jumps over the lazy dog"
input_ids = tokenizer.encode(text, return_tensors="pt")[0]
mask = advanced_mlm.whole_word_masking(input_ids)
print(f"Masked {mask.sum().item()} tokens out of {len(input_ids)}")
Masked 1 tokens out of 11
Tip: MLM Training Best Practices

Tokenizer First: Always train a domain-specific tokenizer before MLM. Generic tokenizers fragment domain terms.

Masking Strategy: Use whole-word masking for semantic learning, span masking for longer dependencies.

Adaptation vs. Scratch: With fewer than ~100M domain tokens, adapt a pre-trained model. With more than ~1B tokens and a highly specialized vocabulary, training from scratch pays off.

Hyperparameters: Standard BERT hyperparameters (lr=5e-5, batch=32, warmup=10%) work well. For adaptation, use lr=2e-5.

Compute Budget: 100M parameters × 1B tokens ≈ 500 GPU-hours. Use mixed precision (fp16) to reduce by 2x.
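A minimal sketch of the mixed-precision recommendation using PyTorch automatic mixed precision; `model`, `train_loader`, and `optimizer` are placeholders for whatever MLM model, data loader, and optimizer you are training (the Trainer sketch above enables the same thing with `fp16=True`):

import torch

scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid fp16 gradient underflow

for batch in train_loader:            # placeholder DataLoader yielding tokenized batches
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # run the forward pass in reduced precision
        loss = model(**batch).loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()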

17.3 Vision Transformers for Industrial Imagery

Vision Transformers (ViTs) combined with self-supervised learning enable training on unlabeled industrial imagery—manufacturing defects, medical scans, satellite images, security footage.

17.3.1 Self-Supervised Vision Transformers

Masked Autoencoder for Vision
import torch
import torch.nn as nn
from einops import rearrange


class MaskedAutoencoderViT(nn.Module):
    """Masked Autoencoder (MAE) for vision transformers - self-supervised image embeddings."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768,
                 depth=12, num_heads=12, decoder_embed_dim=512, decoder_num_heads=8, decoder_depth=8, mask_ratio=0.75):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        num_patches = (img_size // patch_size) ** 2

        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim) * 0.02)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=embed_dim * 4, batch_first=True),
            num_layers=depth
        )
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(decoder_embed_dim, decoder_num_heads, dim_feedforward=decoder_embed_dim * 4, batch_first=True),
            num_layers=decoder_depth
        )
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels)

    def forward(self, x):
        # Patchify and embed
        x = self.patch_embed(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = x + self.pos_embed

        # Random masking
        B, N, D = x.shape
        len_keep = int(N * (1 - self.mask_ratio))
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Encode visible patches
        latent = self.encoder(x_masked)

        # Decode and reconstruct (simplified: zero vectors stand in for the learnable
        # mask token used in the original MAE, and no decoder positional embeddings)
        latent_full = torch.zeros(B, N, D, device=x.device)
        latent_full = torch.scatter(latent_full, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D), latent)
        decoded = self.decoder_embed(latent_full)
        reconstructed = self.decoder_pred(self.decoder(decoded, decoded))

        return reconstructed, ids_restore


# Usage example
mae = MaskedAutoencoderViT(img_size=224, patch_size=16, mask_ratio=0.75)
images = torch.randn(4, 3, 224, 224)
reconstructed, ids = mae(images)
print(f"Input: {images.shape}, Reconstructed: {reconstructed.shape}")
print(f"Masked {mae.mask_ratio * 100}% of patches during training")
Input: torch.Size([4, 3, 224, 224]), Reconstructed: torch.Size([4, 196, 768])
Masked 75.0% of patches during training
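The module returns per-patch reconstructions; during training, MAE computes the loss only over masked patches. A minimal sketch of that loss, assuming a hypothetical `patchify` helper that flattens each image into per-patch pixel vectors and reusing `reconstructed`, `images`, and `ids` from the example above:

import torch
from einops import rearrange


def patchify(images, patch_size=16):
    """Flatten images into per-patch pixel vectors: (B, N, patch_size**2 * C)."""
    return rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)


def mae_loss(reconstructed, images, ids_restore, mask_ratio=0.75, patch_size=16):
    """Mean squared error over masked patches only; visible patches contribute nothing."""
    target = patchify(images, patch_size)                     # (B, 196, 768) for 224x224 RGB
    B, N, _ = target.shape
    len_keep = int(N * (1 - mask_ratio))
    # Rebuild the binary mask from the shuffle order: 0 = visible, 1 = masked.
    mask = torch.ones(B, N, device=images.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    per_patch = ((reconstructed - target) ** 2).mean(dim=-1)  # (B, N)
    return (per_patch * mask).sum() / mask.sum()


loss = mae_loss(reconstructed, images, ids)
print(f"MAE reconstruction loss on masked patches: {loss.item():.4f}")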
Tip: ViT Self-Supervision Best Practices

Mask Ratio: MAE uses 75% masking (aggressive!). This works because images have high redundancy. For specialized imagery (e.g., X-rays), try 50-60%.

Patch Size: Standard is 16x16 for 224x224 images. For higher resolution (512x512+), use 32x32 patches.

Augmentation: Strong augmentations (color jitter, blur) improve robustness. But avoid augmentations that change semantics (e.g., don’t flip medical images if orientation matters).

Compute: ViT-Base with MAE requires ~100 GPU-hours for 1M images. Use a smaller variant such as ViT-Small (~22M params) or ViT-Tiny (~5.7M params) for faster prototyping.

17.3.2 Industrial Vision Applications

Industrial Defect Detection
import torch
import torch.nn as nn


class IndustrialDefectDetection:
    """Self-supervised defect detection for manufacturing using image reconstruction."""

    def __init__(self, encoder_model, image_size=256, embedding_dim=512):
        self.encoder = encoder_model
        self.image_size = image_size
        self.embedding_dim = embedding_dim
        self.decoder = self._build_decoder()
        self.threshold = None

    def _build_decoder(self):
        return nn.Sequential(
            nn.Linear(self.embedding_dim, 1024), nn.ReLU(),
            nn.Linear(1024, 2048), nn.ReLU(),
            nn.Linear(2048, self.image_size * self.image_size * 3), nn.Sigmoid()
        )

    def train_on_normal_samples(self, normal_images, epochs=50, batch_size=32):
        """Train on defect-free samples to learn normal patterns."""
        optimizer = torch.optim.Adam(list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=1e-4)
        criterion = nn.MSELoss()

        for epoch in range(epochs):
            total_loss = 0
            for i in range(0, len(normal_images), batch_size):
                batch = normal_images[i:i + batch_size]
                embeddings = self.encoder(batch)
                reconstructed = self.decoder(embeddings).view(-1, 3, self.image_size, self.image_size)
                loss = criterion(reconstructed, batch)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            if epoch % 10 == 0:
                print(f"Epoch {epoch}: Loss = {total_loss:.4f}")

        self._calibrate_threshold(normal_images)

    def _calibrate_threshold(self, normal_images):
        """Set anomaly threshold based on reconstruction errors of normal samples."""
        self.encoder.eval()
        self.decoder.eval()
        errors = []
        with torch.no_grad():
            for img in normal_images:
                emb = self.encoder(img.unsqueeze(0))
                recon = self.decoder(emb).view(-1, 3, self.image_size, self.image_size)
                error = ((recon - img.unsqueeze(0)) ** 2).mean().item()
                errors.append(error)
        errors_t = torch.tensor(errors)
        # Flag anything beyond mean + 3 std of the normal-sample reconstruction error.
        self.threshold = (errors_t.mean() + 3 * errors_t.std()).item()

    def detect_defects(self, test_image):
        """Detect defects by comparing reconstruction error to threshold."""
        self.encoder.eval()
        self.decoder.eval()
        with torch.no_grad():
            emb = self.encoder(test_image.unsqueeze(0))
            recon = self.decoder(emb).view(-1, 3, self.image_size, self.image_size)
            error = ((recon - test_image.unsqueeze(0)) ** 2).mean().item()
        is_defect = error > self.threshold
        return is_defect, error


# Usage example
encoder = nn.Sequential(nn.Flatten(), nn.Linear(256 * 256 * 3, 512), nn.ReLU())
detector = IndustrialDefectDetection(encoder, image_size=256)
normal_samples = torch.randn(100, 3, 256, 256)
# detector.train_on_normal_samples(normal_samples, epochs=10)
print("Defect detector trained on normal samples only")
Defect detector trained on normal samples only

17.4 Time-Series Self-Supervision

Time-series data (sensor readings, financial data, user activity logs) presents unique self-supervision opportunities due to temporal structure.

17.4.1 Time-Series Pretext Tasks

Time Series Self-Supervised Learning
import torch
import torch.nn as nn


class TimeSeriesSelfSupervised:
    """Self-supervised learning for time series: masking, forecasting, contrastive learning."""

    def __init__(self, input_dim, hidden_dim=256, num_layers=4, task="forecasting"):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.task = task
        self.encoder = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)

        if task == "forecasting":
            self.predictor = nn.Linear(hidden_dim, input_dim)
        elif task == "masking":
            self.predictor = nn.Linear(hidden_dim, input_dim)
        elif task == "contrastive":
            self.projector = nn.Sequential(nn.Linear(hidden_dim, 128), nn.ReLU(), nn.Linear(128, 64))

    def create_forecasting_task(self, timeseries, forecast_horizon=10):
        """Predict future values from past context."""
        context = timeseries[:, :-forecast_horizon, :]
        target = timeseries[:, -forecast_horizon:, :]
        return context, target

    def create_masking_task(self, timeseries, mask_ratio=0.15):
        """Mask random timesteps and predict them."""
        B, T, D = timeseries.shape
        mask = torch.rand(B, T, 1) < mask_ratio
        masked_series = timeseries.clone()
        masked_series[mask.expand_as(timeseries)] = 0
        return masked_series, timeseries, mask

    def forward(self, x):
        """Encode time series to embeddings."""
        _, (h_n, _) = self.encoder(x)
        return h_n[-1]

    def train_step(self, batch, optimizer):
        """Single training step for chosen task."""
        if self.task == "forecasting":
            context, target = self.create_forecasting_task(batch)
            embedding = self.forward(context)
            predictions = self.predictor(embedding).unsqueeze(1).repeat(1, target.size(1), 1)
            loss = nn.MSELoss()(predictions, target)
        elif self.task == "masking":
            masked, original, mask = self.create_masking_task(batch)
            output, _ = self.encoder(masked)
            predictions = self.predictor(output)
            loss = nn.MSELoss()(predictions[mask.expand_as(predictions)], original[mask.expand_as(original)])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()


# Usage example
ts_model = TimeSeriesSelfSupervised(input_dim=10, hidden_dim=128, task="forecasting")
timeseries_data = torch.randn(32, 100, 10)
# Include the predictor head; otherwise it is never updated.
optimizer = torch.optim.Adam(
    list(ts_model.encoder.parameters()) + list(ts_model.predictor.parameters()), lr=1e-3
)
loss = ts_model.train_step(timeseries_data, optimizer)
print(f"Time series SSL loss: {loss:.4f}")
Time series SSL loss: 0.9780
Tip: Time-Series SSL Best Practices

Forecasting Horizon: For high-frequency data (milliseconds), predict 5-10 steps ahead. For slow-varying data (daily), predict 1-2 steps.

Masking Strategy: For bursty data (event logs), use random masking. For smooth data (temperature), use contiguous span masking.

Augmentations: Test augmentations carefully. Ensure they preserve semantic meaning (e.g., don’t shift phase of financial data).

Architecture: Transformers work well for long sequences (> 100 steps). For shorter sequences or limited compute, use LSTM/GRU.
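For the long-sequence case, a minimal sketch of a Transformer-based encoder that could replace the LSTM used above; the class name and layer sizes are illustrative choices, not a reference implementation:

import torch
import torch.nn as nn


class TransformerTimeSeriesEncoder(nn.Module):
    """Illustrative Transformer encoder for long time series (> 100 steps)."""

    def __init__(self, input_dim, hidden_dim=256, num_layers=4, num_heads=4, max_len=1024):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, hidden_dim) * 0.02)
        layer = nn.TransformerEncoderLayer(hidden_dim, num_heads,
                                           dim_feedforward=hidden_dim * 4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x):
        # x: (batch, time, features) -> (batch, hidden_dim) via mean pooling over time
        h = self.input_proj(x) + self.pos_embed[:, :x.size(1), :]
        return self.encoder(h).mean(dim=1)


ts_encoder = TransformerTimeSeriesEncoder(input_dim=10, hidden_dim=128)
print(ts_encoder(torch.randn(32, 500, 10)).shape)  # torch.Size([32, 128])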

17.5 Multi-Modal Self-Supervised Approaches

Multi-modal self-supervision learns from multiple data types simultaneously—text + images, audio + video, sensor + text logs.

17.5.1 CLIP-Style Multi-Modal Learning

Multimodal Self-Supervised Learning
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultimodalSelfSupervised:
    """CLIP-style multimodal self-supervised learning for text-image alignment."""

    def __init__(self, text_encoder, image_encoder, embedding_dim=512, temperature=0.07):
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.temperature = temperature
        self.text_projection = nn.Linear(embedding_dim, embedding_dim)
        self.image_projection = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, text_inputs, image_inputs):
        """Compute embeddings for both modalities."""
        text_features = self.text_encoder(text_inputs)
        if text_features.dim() == 3:  # (batch, seq, dim) -> mean pool to (batch, dim)
            text_features = text_features.mean(dim=1)
        text_embeds = self.text_projection(text_features)
        text_embeds = F.normalize(text_embeds, dim=-1)

        image_features = self.image_encoder(image_inputs)
        image_embeds = self.image_projection(image_features)
        image_embeds = F.normalize(image_embeds, dim=-1)

        return text_embeds, image_embeds

    def contrastive_loss(self, text_embeds, image_embeds):
        """Symmetric contrastive loss (text-to-image and image-to-text)."""
        logits = torch.matmul(text_embeds, image_embeds.T) / self.temperature
        labels = torch.arange(len(text_embeds), device=logits.device)

        loss_t2i = F.cross_entropy(logits, labels)
        loss_i2t = F.cross_entropy(logits.T, labels)
        return (loss_t2i + loss_i2t) / 2

    def train_step(self, text_batch, image_batch, optimizer):
        """Train on paired text-image data."""
        text_embeds, image_embeds = self.forward(text_batch, image_batch)
        loss = self.contrastive_loss(text_embeds, image_embeds)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            logits = torch.matmul(text_embeds, image_embeds.T)
            predictions = logits.argmax(dim=-1)  # For each text, which image is most similar?
            labels = torch.arange(len(text_embeds), device=logits.device)
            accuracy = (predictions == labels).float().mean()

        return loss.item(), accuracy.item()


# Usage example
text_enc = nn.Sequential(nn.Embedding(10000, 512), nn.Linear(512, 512), nn.ReLU())
image_enc = nn.Sequential(nn.Flatten(), nn.Linear(224 * 224 * 3, 512), nn.ReLU())
multimodal = MultimodalSelfSupervised(text_enc, image_enc, embedding_dim=512)

text = torch.randint(0, 10000, (32, 50))
images = torch.randn(32, 3, 224, 224)
# Include the projection heads; otherwise they are never updated.
optimizer = torch.optim.Adam(
    list(multimodal.text_encoder.parameters()) + list(multimodal.image_encoder.parameters())
    + list(multimodal.text_projection.parameters()) + list(multimodal.image_projection.parameters()),
    lr=1e-4,
)
loss, acc = multimodal.train_step(text, images, optimizer)
print(f"Multimodal loss: {loss:.4f}, Alignment accuracy: {acc:.2%}")
Multimodal loss: 3.4942, Alignment accuracy: 0.00%
Tip: Multi-Modal SSL Best Practices

Pairing Quality: The quality of modality pairs matters more than quantity. 10M high-quality pairs > 100M noisy pairs.

Batch Size: Larger batches provide more in-batch negative samples. Use at least 256, ideally 1024+ with gradient accumulation (see the sketch after this tip).

Temperature: Start with 0.07. Lower (0.01) for fine-grained matching, higher (0.2) for coarse similarity.

Modality Balance: If one modality is much noisier, consider weighted loss or filtering poor pairs.

Compute: CLIP-scale training (400M pairs) requires thousands of GPU-hours. For enterprise, 1M-10M pairs often sufficient.
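A minimal sketch of gradient accumulation for approximating a large batch on limited memory, reusing `multimodal` and `optimizer` from the example above; `paired_loader` is a hypothetical DataLoader of (text, image) pairs. Note the caveat in the final comment about in-batch negatives.

accumulation_steps = 4  # e.g. 4 micro-batches of 256 ~ an effective batch of 1024

optimizer.zero_grad()
for step, (text_batch, image_batch) in enumerate(paired_loader):  # hypothetical paired DataLoader
    text_embeds, image_embeds = multimodal.forward(text_batch, image_batch)
    loss = multimodal.contrastive_loss(text_embeds, image_embeds) / accumulation_steps
    loss.backward()  # gradients accumulate across micro-batches

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Caveat: each micro-batch still only contrasts against its own negatives, so accumulation
# smooths gradients without enlarging the negative pool; CLIP-style training additionally
# gathers embeddings across GPUs to get more negatives per step.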

17.6 Key Takeaways

  • Self-supervised learning unlocks unlabeled data at unprecedented scale. No manual annotation needed—data structure provides supervision through pretext tasks.

  • Masked Language Modeling is the foundation for domain-specific text embeddings. Always train a domain-specific tokenizer first, then adapt or train MLM on your corpus.

  • Vision Transformers with Masked Autoencoding (MAE) enable learning from unlabeled images with 75% masking. Ideal for manufacturing defects, medical imaging, and satellite imagery where labels are scarce.

  • Time-series self-supervision uses forecasting, masked reconstruction, or contrastive tasks. Choose based on data characteristics: forecasting for ordered data, contrastive for augmentable data.

  • Multi-modal self-supervision creates shared embedding spaces across text, images, audio, and sensors without paired labels. Contrastive learning between modalities is highly effective.

  • Production deployment requires distributed training, checkpointing, and careful data management. For 100M parameters × 1B samples, expect 500-1000 GPU-hours.

17.7 Looking Ahead

In Chapter 18, we explore advanced embedding techniques that push beyond standard architectures—hierarchical embeddings for taxonomies, dynamic embeddings that evolve over time, compositional embeddings for combinatorial spaces, and quantum-inspired embeddings for ultra-high-dimensional data. These techniques unlock capabilities impossible with standard approaches.

17.8 Further Reading

  • Devlin, J., et al. (2018). “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” NAACL.
  • He, K., et al. (2021). “Masked Autoencoders Are Scalable Vision Learners.” CVPR.
  • Radford, A., et al. (2021). “Learning Transferable Visual Models From Natural Language Supervision.” ICML.
  • Chen, T., et al. (2020). “A Simple Framework for Contrastive Learning of Visual Representations.” ICML.
  • Oord, A., Li, Y., & Vinyals, O. (2018). “Representation Learning with Contrastive Predictive Coding.” arXiv:1807.03748.
  • Liu, Y., et al. (2019). “RoBERTa: A Robustly Optimized BERT Pretraining Approach.” arXiv:1907.11692.
  • Dosovitskiy, A., et al. (2020). “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.” ICLR.
  • Franceschi, J.Y., et al. (2019). “Unsupervised Scalable Representation Learning for Multivariate Time Series.” NeurIPS.