20  Scaling Embedding Training

Note: Chapter Overview

Training embedding models on trillion-row datasets requires computational infrastructure that goes far beyond single-GPU training. This chapter explores the architectures and techniques that enable embedding training at unprecedented scale: distributed training across hundreds of GPUs, gradient accumulation and mixed precision for memory efficiency, advanced memory optimization techniques, multi-GPU and multi-node coordination strategies, and cost optimization approaches that make large-scale training economically viable. These techniques transform embedding training from a multi-day single-machine task to a multi-hour distributed operation, enabling rapid iteration and larger, more powerful models.

Embedding model training faces unique scaling challenges. Unlike image classification models that process fixed-size inputs, embedding models often work with variable-length sequences, sparse features, and massive vocabularies. Contrastive learning requires large batch sizes (4K-32K samples) for effective negative sampling. Self-supervised pre-training demands processing billions of documents. These requirements push standard training infrastructure to its limits, requiring specialized techniques for efficient distributed training.

20.1 Distributed Training Architectures

Distributed training parallelizes model training across multiple devices, reducing training time from weeks to hours. However, embedding training has unique requirements that distinguish it from standard distributed training: large batch sizes for contrastive learning, sparse feature handling, vocabulary parallelism for large embedding tables, and efficient negative sampling across devices. This section explores architectures that address these challenges.

20.1.1 Parallelism Strategies for Embedding Training

Modern distributed training employs multiple parallelism strategies simultaneously:

Show Distributed Embedding Table
import torch
import torch.distributed as dist
import torch.nn as nn


class DistributedEmbeddingTable(nn.Module):
    """Model-parallel embedding table for large vocabularies split across GPUs."""

    def __init__(self, total_vocab_size, embedding_dim, world_size, rank):
        super().__init__()
        self.total_vocab_size = total_vocab_size
        self.embedding_dim = embedding_dim
        self.world_size = world_size
        self.rank = rank

        # Each GPU holds a contiguous slice of the vocabulary
        assert total_vocab_size % world_size == 0, "vocab size must divide evenly across GPUs"
        self.vocab_per_gpu = total_vocab_size // world_size
        self.vocab_start = rank * self.vocab_per_gpu
        self.vocab_end = (rank + 1) * self.vocab_per_gpu

        # Local embedding table (subset of vocabulary)
        self.embeddings = nn.Embedding(self.vocab_per_gpu, embedding_dim)
        print(f"Rank {rank}: Vocabulary [{self.vocab_start}, {self.vocab_end})")

    def forward(self, input_ids):
        """Lookup embeddings across distributed vocabulary."""
        batch_size, seq_len = input_ids.shape
        output = torch.zeros(batch_size, seq_len, self.embedding_dim, device=input_ids.device)

        # Mask for tokens this GPU is responsible for
        local_mask = (input_ids >= self.vocab_start) & (input_ids < self.vocab_end)

        if local_mask.any():
            local_ids = input_ids[local_mask] - self.vocab_start
            local_embeddings = self.embeddings(local_ids)
            output[local_mask] = local_embeddings

        # All-reduce: Sum embeddings from all GPUs
        dist.all_reduce(output, op=dist.ReduceOp.SUM)
        return output


# Usage example (conceptual - requires distributed setup)
# model = DistributedEmbeddingTable(total_vocab_size=100000, embedding_dim=512, world_size=8, rank=0)
print("Distributed embedding table for model-parallel training")
Distributed embedding table for model-parallel training

For multi-GPU training with PyTorch’s distributed module, you typically launch with torchrun:

# Single node, 8 GPUs
torchrun --nproc_per_node=8 train.py

# Multi-node (4 nodes, 8 GPUs each)
torchrun --nproc_per_node=8 --nnodes=4 --node_rank=0 \
         --master_addr=node0 --master_port=1234 train.py
Show Distributed Training Example
import torch
import torch.nn as nn


class DistributedContrastiveEmbedding(nn.Module):
    """Embedding model for distributed contrastive training."""
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.projection = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, ids):
        return self.projection(self.embeddings(ids))


class DistributedTrainer:
    """Trainer for distributed embedding model."""
    def __init__(self, model, local_rank, world_size):
        self.model = model
        self.device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.world_size = world_size
        self.local_rank = local_rank

    def train_step(self, batch, optimizer):
        anchor = self.model(batch['anchor_ids'])
        positive = self.model(batch['positive_ids'])
        loss = nn.functional.mse_loss(anchor, positive)  # Simplified placeholder; real training would use a contrastive loss such as InfoNCE
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    def save_checkpoint(self, path, epoch, optimizer):
        torch.save({'epoch': epoch, 'model': self.model.state_dict()}, path)

    def cleanup(self):
        pass  # In practice: dist.destroy_process_group()


# Initialize distributed trainer
model = DistributedContrastiveEmbedding(vocab_size=100000, embedding_dim=512)
trainer = DistributedTrainer(model=model, local_rank=0, world_size=1)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

# Training step demo
batch = {
    'anchor_ids': torch.randint(0, 100000, (256,), device=trainer.device),
    'positive_ids': torch.randint(0, 100000, (256,), device=trainer.device)
}
loss = trainer.train_step(batch, optimizer)
print(f"Training step loss: {loss:.4f}")
Training step loss: 0.6637
Tip: Choosing the Right Parallelism Strategy

Use Data Parallelism when:

  • Model fits on single GPU
  • Batch size is primary bottleneck
  • Most layers are data-parallel friendly (convolutions, transformers)

Add Model Parallelism when:

  • Embedding tables > GPU memory (100M+ vocabulary)
  • Single layer > GPU memory (very wide transformer layers)

Add Pipeline Parallelism when:

  • Model depth > memory capacity (100+ transformer layers)
  • High arithmetic intensity (can hide communication latency)

For embedding training:

  • Start with Data Parallelism for the encoder (see the DDP sketch after this list)
  • Add Model Parallelism for large embedding tables
  • Consider Pipeline Parallelism for deep architectures (BERT-Large, GPT-3 scale)
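As a minimal sketch of the first recommendation, wrapping an encoder in PyTorch's DistributedDataParallel handles gradient all-reduce automatically; the encoder module and an already-initialized process group are assumed here (see the multi-node setup example later in this chapter):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_encoder_data_parallel(encoder: nn.Module, local_rank: int) -> nn.Module:
    """Replicate the encoder on each GPU; DDP all-reduces gradients in backward()."""
    encoder = encoder.to(f"cuda:{local_rank}")
    # Assumes dist.init_process_group() has already been called
    return DDP(encoder, device_ids=[local_rank])
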
Warning: Communication Bottlenecks

Distributed training speedup is limited by communication:

  • All-reduce (gradient sync): O(parameters × world_size)
  • All-gather (activations): O(batch_size × hidden_dim × world_size)
  • Point-to-point (pipeline): O(hidden_dim × micro_batch_size)

Optimizations:

  • Gradient compression: Reduce precision (FP32 → FP16 gradients; see the comm-hook sketch below)
  • Overlap communication and computation: Backward pass while communicating gradients
  • Hierarchical reduction: Node-local reduction, then cross-node
  • Faster interconnect: InfiniBand (200 Gbps) vs Ethernet (10-100 Gbps)
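As a concrete example of gradient compression, DDP exposes communication hooks that cast gradients to FP16 just for the all-reduce; a minimal sketch, assuming `model` is already wrapped in DistributedDataParallel:

from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Gradients are compressed to FP16 for the all-reduce (halving gradient traffic)
# and cast back to their original dtype before the optimizer step.
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)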

20.2 Gradient Accumulation and Mixed Precision

Memory is the primary constraint in deep learning training. A single NVIDIA A100 GPU offers at most 80GB of memory, yet contrastive training of large embedding models exhausts it quickly: with a 32K-sample batch, the 32K × 32K in-batch similarity matrix alone occupies about 4GB in FP32, and the encoder activations stored for the backward pass across every token and layer are far larger. Gradient accumulation enables large effective batch sizes by splitting batches into smaller micro-batches, while mixed precision reduces memory footprint and accelerates computation by using FP16 for most operations while maintaining FP32 where numerical stability requires it.
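A quick back-of-the-envelope check of that arithmetic (FP32 throughout; illustrative only):

batch_size = 32_768      # contrastive batch
embedding_dim = 512
bytes_fp32 = 4

# The final embedding matrix is small...
embeddings_gb = batch_size * embedding_dim * bytes_fp32 / 1024**3   # ~0.06 GB
# ...but the pairwise similarity matrix for in-batch negatives is not
similarity_gb = batch_size * batch_size * bytes_fp32 / 1024**3      # 4.0 GB

print(f"Embeddings: {embeddings_gb:.2f} GB, similarity matrix: {similarity_gb:.1f} GB")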

20.2.1 Gradient Accumulation for Large Batch Training

Contrastive learning benefits from large batch sizes—more negatives improve representation quality. But memory limits batch size. Gradient accumulation solves this:

Show Gradient Accumulation Trainer
import torch
import torch.nn as nn


class GradientAccumulationTrainer:
    """Enable large effective batch sizes through gradient accumulation."""

    def __init__(self, model, accumulation_steps=4):
        self.model = model
        self.accumulation_steps = accumulation_steps

    def train_step(self, dataloader, optimizer, device="cuda"):
        """Training step with gradient accumulation."""
        self.model.train()
        optimizer.zero_grad()
        total_loss = 0.0

        for i, batch in enumerate(dataloader):
            if i >= self.accumulation_steps:
                break

            anchor_ids = batch["anchor_ids"].to(device)
            positive_ids = batch["positive_ids"].to(device)

            loss = self.model(anchor_ids, positive_ids)
            loss = loss / self.accumulation_steps  # Scale loss
            loss.backward()  # Accumulate gradients
            total_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

        return total_loss


# Usage example
# model = EmbeddingModel()
# trainer = GradientAccumulationTrainer(model, accumulation_steps=32)
print("Gradient accumulation enables 32K+ effective batch sizes")
Gradient accumulation enables 32K+ effective batch sizes

20.2.2 Mixed Precision Training

Modern GPUs (Volta, Turing, Ampere architectures) have specialized Tensor Cores that accelerate FP16 matrix multiplications by 2-8×. Mixed precision uses FP16 for computation while maintaining FP32 for numerical stability:

Show Mixed Precision Trainer
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast


class MixedPrecisionTrainer:
    """Automatic mixed precision (AMP) training for 1.5-2x speedup (workload-dependent)."""

    def __init__(self, model, device="cuda"):
        self.model = model.to(device)
        self.device = device
        self.scaler = GradScaler()

    def train_step(self, batch, optimizer):
        """Training step with automatic mixed precision."""
        self.model.train()
        anchor_ids = batch["anchor_ids"].to(self.device)
        positive_ids = batch["positive_ids"].to(self.device)

        optimizer.zero_grad()

        # Forward pass in FP16
        with autocast():
            loss = self.model(anchor_ids, positive_ids)

        # Backward with gradient scaling
        self.scaler.scale(loss).backward()
        self.scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.scaler.step(optimizer)
        self.scaler.update()

        return loss.item()


# Usage example
# trainer = MixedPrecisionTrainer(model)
print("Mixed precision training: 1.5-2x speedup typical on modern GPUs")
Mixed precision training: 1.5-2x speedup typical on modern GPUs
Tip: When to Use Gradient Accumulation vs Larger Hardware

Use gradient accumulation when:

  • Memory-constrained (batch won’t fit on GPU)
  • Want to experiment with very large batches (64K+)
  • Training on cloud instances with limited GPU memory

Upgrade hardware when:

  • Wall-clock time is critical (accumulation is slower)
  • Training very frequently (hardware cost amortizes)
  • Need to scale beyond single node (distributed > accumulation)

Use mixed precision almost always:

  • Modern GPUs (V100, A100) have Tensor Cores
  • 1.5-2× speedup with minimal code changes
  • Rarely causes numerical issues (except very deep networks); a combined accumulation + AMP sketch follows this list
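In practice, gradient accumulation and mixed precision are combined; the sketch below shows one accumulation cycle following the documented AMP pattern (the `model(anchor_ids, positive_ids)` signature mirrors the earlier examples and is an assumption, not a requirement):

import torch
from torch.cuda.amp import GradScaler, autocast


def accumulated_amp_step(model, micro_batches, optimizer, scaler, accumulation_steps, device="cuda"):
    """One optimizer step spread across several FP16 micro-batches."""
    optimizer.zero_grad()
    for i, batch in enumerate(micro_batches):
        with autocast():
            loss = model(batch["anchor_ids"].to(device), batch["positive_ids"].to(device))
            loss = loss / accumulation_steps          # keep gradient magnitude consistent
        scaler.scale(loss).backward()                 # accumulate scaled gradients
        if (i + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)                # unscale before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
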
Warning: Mixed Precision Gotchas

Gradient underflow: Very small gradients (below roughly 6e-8, the FP16 subnormal limit) round to zero in FP16. Gradient scaling addresses this, but extreme cases may need:

  • Larger learning rates
  • Loss scaling adjustments (sketched below)
  • FP32 for sensitive layers (layer norm, softmax)
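If the default loss-scaling schedule over- or underflows, GradScaler's schedule can be tuned directly; the values below are illustrative, not recommendations:

from torch.cuda.amp import GradScaler

scaler = GradScaler(
    init_scale=2.0 ** 12,   # default is 2**16; start lower if early steps overflow
    growth_factor=2.0,      # multiply the scale after `growth_interval` clean steps
    backoff_factor=0.5,     # divide the scale after a step with inf/NaN gradients
    growth_interval=1000,   # default is 2000
)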

Batch normalization: BatchNorm statistics in FP16 can be unstable. Use FP32 for BatchNorm layers:

# Manual full-FP16 conversion (only needed when not using torch.cuda.amp autocast)
model = model.half()  # Convert parameters and buffers to FP16
# Keep BatchNorm layers in FP32 for stable running statistics
for module in model.modules():
    if isinstance(module, nn.BatchNorm1d):
        module.float()

20.3 Memory Optimization Techniques

Beyond mixed precision and gradient accumulation, several techniques reduce memory footprint, enabling larger models and batch sizes:

20.3.1 Gradient Checkpointing

Trade computation for memory by recomputing activations during backward pass instead of storing them:

Show Gradient Checkpointing
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedTransformerLayer(nn.Module):
    """Transformer layer with gradient checkpointing for memory efficiency."""

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
        self.ffn = nn.Sequential(nn.Linear(hidden_dim, hidden_dim * 4), nn.ReLU(), nn.Linear(hidden_dim * 4, hidden_dim))
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Forward with gradient checkpointing to save memory."""
        def attention_forward(x):
            attn_out, _ = self.attention(x, x, x)
            return self.norm1(x + attn_out)

        x = checkpoint(attention_forward, x, use_reentrant=False)

        def ffn_forward(x):
            return self.norm2(x + self.ffn(x))

        x = checkpoint(ffn_forward, x, use_reentrant=False)
        return x


# Usage example
# layer = CheckpointedTransformerLayer(hidden_dim=512, num_heads=8)
print("Gradient checkpointing: 10-50x memory reduction")
Gradient checkpointing: 10-50x memory reduction

20.3.2 Optimizer State Optimization

The Adam optimizer stores two extra FP32 values (momentum and variance) per parameter, roughly tripling parameter-related memory. Optimizations:

Show Memory-Efficient Optimizer
import torch
import torch.optim as optim


class MemoryEfficientOptimizer:
    """Optimize memory usage for large models using efficient optimizers."""

    @staticmethod
    def get_optimizer(parameters, optimizer_type="adamw", lr=0.001):
        """Get memory-efficient optimizer."""
        if optimizer_type == "adamw":
            return optim.AdamW(parameters, lr=lr, fused=True)  # fused kernel requires CUDA parameters
        elif optimizer_type == "sgd":
            return optim.SGD(parameters, lr=lr, momentum=0.9, nesterov=True)
        elif optimizer_type == "8bit_adam":
            try:
                import bitsandbytes as bnb
                return bnb.optim.Adam8bit(parameters, lr=lr)
            except ImportError:
                print("bitsandbytes not installed, using AdamW")
                return optim.AdamW(parameters, lr=lr)
        else:
            raise ValueError(f"Unknown optimizer type: {optimizer_type}")


# Usage example
# params = model.parameters()
# optimizer = MemoryEfficientOptimizer.get_optimizer(params, "8bit_adam")
print("8-bit optimizers: 4x memory reduction vs standard Adam")
8-bit optimizers: 4x memory reduction vs standard Adam
Tip: Memory Optimization Checklist

When hitting memory limits, apply optimizations in this order:

  1. Mixed precision (FP16): 2× memory reduction, 1.5-2× speedup
  2. Gradient accumulation: Enables larger effective batch sizes
  3. Gradient checkpointing: 10-50× activation memory reduction
  4. Optimizer state optimization: 8-bit Adam or SGD
  5. Model parallelism: Split model across GPUs
  6. Batch size reduction: Last resort (hurts contrastive learning)

Typical savings:

  • FP16: 40GB → 20GB
  • Checkpointing: 20GB → 8GB
  • 8-bit optimizer: 8GB → 5GB
  • Result: fits on a single A100 (80GB) with room for a large batch

20.4 Multi-GPU and Multi-Node Strategies

Scaling beyond single GPU requires coordination across devices. This section covers practical strategies for multi-GPU (single node) and multi-node (multiple machines) training.

20.4.1 Multi-GPU Training on Single Node

Single-node multi-GPU is the most common setup (8× A100 or V100 GPUs on one machine):

Show Distributed Data Loading
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler


class EmbeddingDataset(Dataset):
    """Dataset for efficient distributed embedding training."""

    def __init__(self, data_path, sequence_length=512):
        self.data_path = data_path
        self.sequence_length = sequence_length
        self.data = []  # Load from data_path in practice
        self.num_samples = 10_000  # Placeholder size; __getitem__ below generates synthetic pairs

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {"anchor_ids": torch.randint(0, 50000, (self.sequence_length,)),
                "positive_ids": torch.randint(0, 50000, (self.sequence_length,))}


def setup_distributed_dataloaders(dataset, batch_size, world_size, rank):
    """Create distributed dataloaders with proper sharding."""
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                            num_workers=4, pin_memory=True, prefetch_factor=2)
    return dataloader


# Usage example
# dataset = EmbeddingDataset("data.parquet")
# loader = setup_distributed_dataloaders(dataset, batch_size=256, world_size=8, rank=0)
print("Distributed dataloaders ensure each GPU sees unique data")
Distributed dataloaders ensure each GPU sees unique data

20.4.2 Multi-Node Training

Multi-node training scales to hundreds of GPUs across dozens of machines:

Show Multi-Node Training Setup
import os
import torch
import torch.distributed as dist


def setup_multi_node():
    """Initialize multi-node distributed training environment."""
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    master_addr = os.environ.get('MASTER_ADDR', 'localhost')
    master_port = os.environ.get('MASTER_PORT', '12355')

    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port

    print(f"Initializing process group: rank={rank}, world_size={world_size}, local_rank={local_rank}")

    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)

    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank


# Usage example - run with torchrun:
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=$NODE_RANK train.py
# rank, world_size, local_rank = setup_multi_node()
print("Multi-node setup: coordinate training across multiple machines")
Multi-node setup: coordinate training across multiple machines

For multi-node training, use SLURM or torchrun to launch across machines:

# SLURM submission
sbatch --nodes=4 --gres=gpu:8 train_multi_node.sh

# Or with torchrun on each node:
torchrun --nproc_per_node=8 \
         --nnodes=4 \
         --node_rank=$NODE_RANK \
         --master_addr=$MASTER_ADDR \
         --master_port=1234 \
         train_script.py
Tip: Multi-GPU Best Practices

Data loading:

  • Use DistributedSampler to partition data across GPUs
  • Set num_workers=4 per GPU for async data loading
  • Use pin_memory=True for faster CPU→GPU transfer

Learning rate scaling:

  • Scale learning rate linearly with batch size
  • 1 GPU (batch 512, lr 0.001) → 8 GPUs (batch 4096, lr 0.008)
  • May need warmup for large learning rates

Synchronization:

  • Minimize dist.barrier() calls (blocks all GPUs)
  • Overlap communication with computation
  • Use find_unused_parameters=False in DDP when possible

Checkpointing:

  • Only save from rank 0 to avoid duplicate writes (see the sketch after this list)
  • Use dist.barrier() after saving to synchronize
  • Consider sharded checkpointing for very large models
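A minimal sketch of the rank-0 checkpointing pattern above, assuming an initialized process group and a DDP-wrapped model:

import torch
import torch.distributed as dist


def save_checkpoint_rank0(model, optimizer, epoch, path):
    """Write the checkpoint once from rank 0, then synchronize all ranks."""
    if dist.get_rank() == 0:
        torch.save({
            "epoch": epoch,
            "model": model.module.state_dict(),   # unwrap DistributedDataParallel
            "optimizer": optimizer.state_dict(),
        }, path)
    dist.barrier()  # keep other ranks from racing ahead before the file is written
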
Warning: Multi-Node Challenges

Network bottlenecks:

  • Cross-node communication 10-100× slower than NVLink
  • Use gradient compression or a ZeRO-style optimizer (see the sketch after this list)
  • Consider hierarchical all-reduce (node-local first)

Fault tolerance:

  • Single node failure kills entire job
  • Implement checkpointing every N steps
  • Use elastic training frameworks (TorchElastic)

Load imbalance:

  • Stragglers slow down entire cluster
  • Monitor per-GPU utilization
  • Use dynamic batch sizing if variability high
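For the ZeRO-style option mentioned above, PyTorch ships a native stage-1 implementation that shards optimizer state across ranks (gradients are still all-reduced by DDP); a minimal sketch assuming a DDP-wrapped `model` and an initialized process group:

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# Each rank stores only its shard of Adam's momentum/variance state instead of
# a full replica, cutting per-GPU optimizer memory roughly by the world size.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=1e-3,
)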

20.5 Training Cost Optimization

Large-scale training is expensive. A 100-GPU training run can cost $10K-$100K. This section covers strategies to minimize cost while maintaining quality.

20.5.1 Cloud Cost Optimization

Show Cloud Cost Optimization
import time
from datetime import datetime


class CloudCostOptimizer:
    """Optimize training costs through instance selection and resource management."""

    def __init__(self, budget_per_hour=100.0):
        self.budget_per_hour = budget_per_hour
        self.instance_costs = {
            "p3.2xlarge": 3.06,    # V100
            "p4d.24xlarge": 32.77,  # A100
            "g5.xlarge": 1.006      # T4
        }

    def select_instance_config(self, target_gpus, prefer_cost=True):
        """Select optimal instance configuration based on budget and requirements."""
        configs = []

        for instance_type, hourly_cost in self.instance_costs.items():
            gpus_per_instance = {"p3.2xlarge": 1, "p4d.24xlarge": 8, "g5.xlarge": 1}[instance_type]

            num_instances = (target_gpus + gpus_per_instance - 1) // gpus_per_instance
            total_cost = num_instances * hourly_cost

            if total_cost <= self.budget_per_hour:
                configs.append({"instance_type": instance_type, "num_instances": num_instances,
                                "total_gpus": num_instances * gpus_per_instance, "hourly_cost": total_cost})

        if prefer_cost:
            configs.sort(key=lambda x: x["hourly_cost"])
        else:
            configs.sort(key=lambda x: -x["total_gpus"])

        return configs[0] if configs else None


# Usage example
optimizer = CloudCostOptimizer(budget_per_hour=50.0)
config = optimizer.select_instance_config(target_gpus=8, prefer_cost=True)
print(f"Optimal config: {config}")
Optimal config: {'instance_type': 'g5.xlarge', 'num_instances': 8, 'total_gpus': 8, 'hourly_cost': 8.048}

20.5.2 Spot Instance Training

Spot instances offer 50-90% discounts but can be preempted. Strategies for resilient training:

Show Spot Instance Training
import time
import torch


class SpotInstanceTrainer:
    """Training with spot instance resilience via frequent checkpointing."""

    def __init__(self, model, checkpoint_interval=300):
        self.model = model
        self.checkpoint_interval = checkpoint_interval
        self.last_checkpoint = time.time()

    def train(self, dataloader, optimizer, epochs=10):
        """Train with automatic checkpointing for spot instance resilience."""
        for epoch in range(epochs):
            for batch_idx, batch in enumerate(dataloader):
                try:
                    loss = self.model(batch)
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                    # Checkpoint every N seconds
                    if time.time() - self.last_checkpoint > self.checkpoint_interval:
                        self.save_checkpoint(epoch, batch_idx)
                        self.last_checkpoint = time.time()

                except RuntimeError as e:
                    # Illustrative only: real spot preemption is signaled via an
                    # interruption notice / SIGTERM rather than a Python exception,
                    # so production code should also install a signal handler.
                    if "preempted" in str(e).lower():
                        print("Spot instance preempted! Checkpoint saved.")
                        self.save_checkpoint(epoch, batch_idx)
                        raise
                    else:
                        raise

    def save_checkpoint(self, epoch, batch_idx):
        """Save checkpoint for recovery."""
        checkpoint = {"epoch": epoch, "batch_idx": batch_idx, "model_state": self.model.state_dict()}
        torch.save(checkpoint, f"checkpoint_epoch{epoch}_batch{batch_idx}.pt")
        print(f"Checkpoint saved: epoch {epoch}, batch {batch_idx}")

    def load_checkpoint(self, checkpoint_path):
        """Resume from checkpoint."""
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint["model_state"])
        return checkpoint["epoch"], checkpoint["batch_idx"]


# Usage example
# trainer = SpotInstanceTrainer(model, checkpoint_interval=300)
print("Spot instance training: 50-90% cost savings with checkpointing")
Spot instance training: 50-90% cost savings with checkpointing
Tip: Cost Optimization Strategies

Immediate savings (no quality impact):

  1. Spot instances: 50-90% discount (with checkpointing)
  2. Mixed precision: 1.5-2× speedup → roughly 35-50% compute-cost reduction
  3. Reserved instances: 30-50% discount for long-term projects
  4. Multi-cloud: Compare prices across AWS/GCP/Azure

Advanced optimizations:

  1. Early stopping: Halt when validation loss plateaus (see the sketch after this tip)
  2. Hyperparameter search efficiency: Use Bayesian optimization, not grid search
  3. Model distillation: Train large model, deploy small model
  4. Sparse training: Train only subset of parameters

Typical cost breakdown (100-GPU training):

  • Hardware: 70% (can optimize with spot instances)
  • Storage: 10% (use cheaper object storage)
  • Network: 10% (minimize cross-region transfer)
  • Other: 10% (monitoring, logging, etc.)
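A minimal early-stopping sketch for the "halt when validation loss plateaus" item above; the patience and tolerance values are illustrative:

class EarlyStopping:
    """Stop training once validation loss stops improving."""

    def __init__(self, patience=3, min_delta=1e-4):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # smallest change that counts as improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def should_stop(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Inside the training loop:
# stopper = EarlyStopping(patience=3)
# if stopper.should_stop(validation_loss):
#     break  # stop paying for epochs that no longer improve the model
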

20.6 Key Takeaways

  • Distributed training is essential at scale: Data parallelism for throughput, model parallelism for large embedding tables, and pipeline parallelism for deep architectures combine to enable trillion-row training in reasonable time

  • Gradient accumulation enables large effective batch sizes: Split large batches into micro-batches to fit memory constraints while maintaining the benefits of large-batch contrastive learning (16K-32K samples)

  • Mixed precision training provides 1.5-2× speedup: FP16 computation on Tensor Cores with FP32 master weights maintains numerical stability while reducing memory usage and accelerating training (actual speedup is workload-dependent)

  • Memory optimization unlocks larger models: Gradient checkpointing, optimizer state quantization (8-bit Adam), and efficient activation management reduce memory footprint by 10-50×, enabling BERT-scale models on single GPUs

  • Multi-node training scales to hundreds of GPUs: Proper configuration of distributed samplers, learning rate scaling, and network topology awareness enable near-linear scaling to 64+ GPUs with 40-50× speedup

  • Cost optimization is critical for sustainable training: Spot instances (50-90% savings), mixed precision speedup, and efficient checkpointing reduce training costs from $100K to $10K-$30K for large models

  • Communication is the bottleneck at scale: Gradient synchronization, activation gathering, and cross-node communication limit speedup; overlap computation with communication and use gradient compression to mitigate

20.7 Looking Ahead

This chapter covered the computational techniques for training embedding models at scale. Chapter 21 addresses a critical question: how do you know if your embeddings are good? We explore intrinsic quality metrics (isotropy, uniformity, alignment), comprehensive retrieval metrics (Recall@K, MAP, NDCG, MRR), human evaluation frameworks, domain-specific metrics, and statistical rigor—providing the measurement foundation for continuous improvement.

20.8 Further Reading

20.8.1 Distributed Training

  • Li, Shen, et al. (2020). “PyTorch Distributed: Experiences on Accelerating Data Parallel Training.” VLDB.
  • Shoeybi, Mohammad, et al. (2019). “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.” arXiv:1909.08053.
  • Rajbhandari, Samyam, et al. (2020). “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.” SC20.

20.8.2 Mixed Precision Training

  • Micikevicius, Paulius, et al. (2018). “Mixed Precision Training.” ICLR.
  • Mellempudi, Naveen, et al. (2019). “Mixed Precision Training With 8-bit Floating Point.” arXiv:1905.12334.

20.8.3 Memory Optimization

  • Chen, Tianqi, et al. (2016). “Training Deep Nets with Sublinear Memory Cost.” arXiv:1604.06174.
  • Sohoni, Nimit, et al. (2019). “Low-Memory Neural Network Training.” arXiv:1904.10631.
  • Dettmers, Tim, et al. (2022). “8-bit Optimizers via Block-wise Quantization.” arXiv:2110.02861.

20.8.4 Large-Scale Training Systems

  • Jia, Xianyan, et al. (2018). “Highly Scalable Deep Learning Training System with Mixed-Precision.” arXiv:1807.11205.
  • Sergeev, Alexander, and Mike Del Balso (2018). “Horovod: Fast and Easy Distributed Deep Learning in TensorFlow.” arXiv:1802.05799.
  • Paszke, Adam, et al. (2019). “PyTorch: An Imperative Style, High-Performance Deep Learning Library.” NeurIPS.

20.8.5 Cost Optimization

  • Chaudhary, Vinay, et al. (2020). “Balancing Efficiency and Flexibility for DNN Acceleration via Temporal GPU-Systolic Array Integration.” DAC.
  • Yang, Tianyi, et al. (2021). “Toward Efficient Deep Learning in the Cloud: Resource Provisioning and Workload Scheduling.” IEEE Cloud Computing.