16 Siamese Networks for Specialized Use Cases

Note: Chapter Overview

While contrastive learning (Chapter 5) taught us how to train embeddings that distinguish similar from dissimilar, Siamese networks provide the architectural foundation for similarity-based learning at enterprise scale. This chapter explores Siamese architectures—twin neural networks that excel at learning similarity metrics for specialized use cases including one-shot learning, anomaly detection, and verification systems. We cover the architectural patterns, triplet loss optimization, strategies for rare event handling, threshold calibration techniques, and production deployment patterns that enable Siamese networks to scale to trillion-row deployments.

16.1 Siamese Architecture for Enterprise Similarity

Siamese networks address a fundamental challenge: how do you learn similarity when you have few examples per class, unbalanced class distributions, or categories that keep changing? Standard classifiers struggle here because they need many labeled examples per class and a label set fixed at training time. Siamese networks sidestep both limits by learning to compare rather than to classify.

16.1.1 The Siamese Paradigm

Named after Siamese twins, a Siamese network consists of two or more identical neural networks (sharing weights) that process different inputs and compare their outputs. The key insight: instead of learning “what is X?”, learn “how similar are X and Y?”

This shift enables:

  • Few-shot learning: Learn from 1-5 examples per class
  • Open-set recognition: Handle classes not seen during training
  • Verification tasks: “Are these the same?” vs “What is this?”
  • Similarity search: Find nearest neighbors in learned space
import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    """
    Siamese Network for learning similarity metrics

    Architecture: Two identical networks (shared weights) process different
    inputs, producing embeddings that are compared using a distance metric.

    Use cases:
    - Face verification: "Is this the same person?"
    - Document similarity: "Are these papers related?"
    - Product matching: "Are these the same item?"
    - Anomaly detection: "Is this different from normal?"
    """

    def __init__(self, embedding_net, embedding_dim=512):
        """
        Args:
            embedding_net: The base network for creating embeddings
                          (e.g., ResNet, BERT, custom architecture)
            embedding_dim: Dimension of output embeddings
        """
        super().__init__()
        self.embedding_net = embedding_net
        self.embedding_dim = embedding_dim

    def forward(self, x1, x2):
        """
        Forward pass through Siamese network

        Args:
            x1: First input (batch_size, ...)
            x2: Second input (batch_size, ...)

        Returns:
            embedding1: Embeddings for x1 (batch_size, embedding_dim)
            embedding2: Embeddings for x2 (batch_size, embedding_dim)
        """
        # Both inputs go through the SAME network (shared weights)
        embedding1 = self.embedding_net(x1)
        embedding2 = self.embedding_net(x2)

        return embedding1, embedding2

    def get_embedding(self, x):
        """Get embedding for a single input"""
        return self.embedding_net(x)


class EmbeddingNet(nn.Module):
    """
    Example embedding network for structured/tabular data

    For images: Use ResNet, EfficientNet, Vision Transformer
    For text: Use BERT, RoBERTa, sentence transformers
    For multimodal: Use CLIP-style architectures
    """

    def __init__(self, input_dim, embedding_dim=512, hidden_dims=[1024, 512]):
        super().__init__()

        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3)
            ])
            prev_dim = hidden_dim

        # Final embedding layer
        layers.append(nn.Linear(prev_dim, embedding_dim))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: Input features (batch_size, input_dim)

        Returns:
            embeddings: L2-normalized embeddings (batch_size, embedding_dim)
        """
        embeddings = self.network(x)
        # L2 normalization for cosine similarity
        return F.normalize(embeddings, p=2, dim=1)


# Example: Building a Siamese network for enterprise use
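def create_enterprise_siamese_network(input_type='tabular', input_dim=None):
    """
    Factory function for creating Siamese networks

    Args:
        input_type: 'tabular', 'image', or 'text'
        input_dim: Input dimension (for tabular data)

    Returns:
        SiameseNetwork instance configured for the input type
    """

    if input_type == 'tabular':
        if input_dim is None:
            raise ValueError("input_dim required for tabular data")
        embedding_dim = 512
        embedding_net = EmbeddingNet(
            input_dim=input_dim,
            embedding_dim=embedding_dim,
            hidden_dims=[1024, 768, 512]
        )

    elif input_type == 'image':
        # Use a pre-trained ResNet-50 (torchvision >= 0.13 weights API)
        import torchvision.models as models
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Drop the classification head and flatten the pooled (batch, 2048, 1, 1) output
        embedding_net = nn.Sequential(*list(resnet.children())[:-1], nn.Flatten())
        embedding_dim = 2048

    elif input_type == 'text':
        # AutoModel returns an output object rather than a tensor, so wrap it and
        # use the final hidden state of the [CLS] token as the embedding
        from transformers import AutoModel

        class CLSEncoder(nn.Module):
            def __init__(self, name='bert-base-uncased'):
                super().__init__()
                self.encoder = AutoModel.from_pretrained(name)

            def forward(self, input_ids, attention_mask=None):
                outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
                return outputs.last_hidden_state[:, 0]

        embedding_net = CLSEncoder()
        embedding_dim = 768  # hidden size of bert-base-uncased

    else:
        raise ValueError(f"Unknown input_type: {input_type}")

    return SiameseNetwork(embedding_net, embedding_dim=embedding_dim)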

16.1.2 Contrastive Loss for Siamese Networks

The classic Siamese network uses contrastive loss to bring similar pairs together and push dissimilar pairs apart:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    """Contrastive loss: Loss = (1-Y)*0.5*D^2 + Y*0.5*max(margin-D, 0)^2"""

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, embedding1, embedding2, label):
        """label: 0 if similar, 1 if dissimilar"""
        euclidean_distance = F.pairwise_distance(embedding1, embedding2)
        loss_similar = (1 - label) * torch.pow(euclidean_distance, 2)
        loss_dissimilar = label * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2
        )
        loss = torch.mean(loss_similar + loss_dissimilar) * 0.5

        with torch.no_grad():
            threshold = self.margin / 2
            predictions = (euclidean_distance < threshold).long()
            accuracy = (predictions == (1 - label)).float().mean()
            similar_mask = label == 0
            dissimilar_mask = label == 1
            metrics = {
                "loss": loss.item(), "accuracy": accuracy.item(),
                "mean_similar_distance": euclidean_distance[similar_mask].mean().item() if similar_mask.any() else 0,
                "mean_dissimilar_distance": euclidean_distance[dissimilar_mask].mean().item() if dissimilar_mask.any() else 0,
            }
        return loss, metrics

# Usage example
loss_fn = ContrastiveLoss(margin=2.0)
emb1 = torch.randn(32, 512)
emb2 = torch.randn(32, 512)
labels = torch.randint(0, 2, (32,))
loss, metrics = loss_fn(emb1, emb2, labels)
print(f"Contrastive loss: {metrics['loss']:.4f}, accuracy: {metrics['accuracy']:.4f}")
Contrastive loss: 316.7425, accuracy: 0.3750
Tip: Choosing Distance Metrics

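Euclidean distance works well for low-dimensional embeddings (< 128). For L2-normalized embeddings, Euclidean and cosine distance are monotonically related, so the choice between them matters mainly when embeddings are not normalized.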

Cosine distance (1 - cosine similarity) is preferred for:

  • High-dimensional embeddings (> 128)
  • Text embeddings
  • When magnitude isn’t meaningful

Learned distance metrics (e.g., Mahalanobis) can capture domain-specific similarity but require more data and computation.
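For unit-length embeddings, squared Euclidean distance and cosine similarity satisfy ||a − b||² = 2 − 2·cos(a, b). The quick PyTorch check below confirms the identity numerically on random, L2-normalized tensors; the batch size, dimension, and seed are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Random embeddings, L2-normalized to unit length (batch size and dimension are arbitrary)
a = F.normalize(torch.randn(8, 256), p=2, dim=1)
b = F.normalize(torch.randn(8, 256), p=2, dim=1)

# eps=0.0 so pairwise_distance computes exactly ||a - b||
squared_euclidean = F.pairwise_distance(a, b, eps=0.0).pow(2)
cosine = F.cosine_similarity(a, b, dim=1)

# For unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b)
print(torch.allclose(squared_euclidean, 2 - 2 * cosine, atol=1e-5))  # True
```

Because 2 − 2·cos is strictly decreasing in cos, sorting candidates by either metric yields the same nearest neighbors.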

16.1.3 Enterprise Siamese Architecture Patterns

For production systems handling billions of comparisons daily, architecture choices matter:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class EnterpriseOptimizedSiameseNetwork(nn.Module):
    """Production-oriented Siamese encoder with optional self-attention and gradient checkpointing.

    Mixed precision is applied outside the model (e.g., wrap the training step in torch.autocast).
    base_model must output features of size embedding_dim, either pooled (batch, dim)
    or per-token (batch, seq_len, dim).
    """

    def __init__(self, base_model, embedding_dim=512, use_attention=True, use_gradient_checkpointing=False):
        super().__init__()
        self.base_model = base_model
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.projection = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim), nn.BatchNorm1d(embedding_dim),
            nn.ReLU(), nn.Linear(embedding_dim, embedding_dim),
        )
        self.attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=8, dropout=0.1, batch_first=True) if use_attention else None

    def forward(self, x1, x2):
        if self.use_gradient_checkpointing and self.training:
            # Recompute activations during backward instead of storing them (less memory, more compute)
            embedding1 = checkpoint(self._encode, x1, use_reentrant=False)
            embedding2 = checkpoint(self._encode, x2, use_reentrant=False)
        else:
            embedding1, embedding2 = self._encode(x1), self._encode(x2)
        return embedding1, embedding2

    def _encode(self, x):
        features = self.base_model(x)
        # Treat pooled features as a length-1 sequence. Self-attention over a single token
        # reduces to the value/output projections; it only mixes information across
        # positions when base_model returns per-token features.
        seq = features if features.dim() == 3 else features.unsqueeze(1)
        if self.attention is not None:
            seq, _ = self.attention(seq, seq, seq)
        features = seq.mean(dim=1)  # pool over the sequence dimension
        embedding = self.projection(features)
        return F.normalize(embedding, p=2, dim=1)

# Usage example
base = nn.Sequential(nn.Linear(128, 512), nn.ReLU())
model = EnterpriseOptimizedSiameseNetwork(base, embedding_dim=512, use_attention=True)
print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")
Model params: 1,643,008
Warning: Production Considerations

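Memory Management: For large models (> 1B parameters), gradient checkpointing is often necessary. It recomputes activations during the backward pass, so it costs roughly one extra forward pass of compute in exchange for substantially lower activation memory; the exact savings depend on which layers are checkpointed.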

Batch Size Selection: Larger batches (256-1024) improve training stability for Siamese networks. Use gradient accumulation if GPU memory is limited.

Learning Rate: Start with 1e-4 for fine-tuning pre-trained models, 1e-3 for training from scratch. Use warmup for stability.
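One common way to implement warmup is a linear ramp with PyTorch's `torch.optim.lr_scheduler.LambdaLR`. The sketch below assumes the 1e-4 fine-tuning rate suggested above, a hypothetical 100-step warmup, and a stand-in linear model; after warmup you would hand over to whatever decay schedule you already use.

```python
import torch
import torch.nn as nn

model = nn.Linear(128, 64)  # stand-in for a Siamese encoder
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

warmup_steps = 100  # assumption: tune for your dataset and batch size

def linear_warmup(step):
    # Multiplier on the base LR: ramps from 1/warmup_steps up to 1.0, then stays flat
    return min(1.0, (step + 1) / warmup_steps)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_warmup)

lrs = []
for step in range(150):
    # A real loop would compute the loss and call loss.backward() before stepping
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])
```

Calling `optimizer.step()` before `scheduler.step()` matches the order PyTorch expects; the recorded rates climb linearly and reach the base rate once the warmup steps are done.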

16.2 Triplet Loss Optimization Techniques

While contrastive loss works with pairs, triplet loss works with triplets: (anchor, positive, negative). This provides more information per training example and often leads to better embeddings.

16.2.1 Triplet Loss Fundamentals

Triplet loss encourages the anchor-positive distance to be smaller than the anchor-negative distance by at least a margin:

Loss = max(d(anchor, positive) - d(anchor, negative) + margin, 0)

import torch
import torch.nn as nn
import torch.nn.functional as F

class TripletLoss(nn.Module):
    """Triplet loss: d(anchor, positive) + margin < d(anchor, negative)"""

    def __init__(self, margin=1.0, distance_metric="euclidean"):
        super().__init__()
        self.margin = margin
        self.distance_metric = distance_metric

    def forward(self, anchor, positive, negative):
        if self.distance_metric == "euclidean":
            pos_distance = F.pairwise_distance(anchor, positive, p=2)
            neg_distance = F.pairwise_distance(anchor, negative, p=2)
        else:  # cosine
            pos_distance = 1 - F.cosine_similarity(anchor, positive)
            neg_distance = 1 - F.cosine_similarity(anchor, negative)

        losses = F.relu(pos_distance - neg_distance + self.margin)
        loss = losses.mean()

        with torch.no_grad():
            hard_triplets = (losses > 0).float().mean()
            accuracy = (pos_distance < neg_distance).float().mean()
            metrics = {
                "loss": loss.item(), "accuracy": accuracy.item(),
                "hard_triplets_fraction": hard_triplets.item(),
                "avg_pos_distance": pos_distance.mean().item(),
                "avg_neg_distance": neg_distance.mean().item(),
            }
        return loss, metrics

# Usage example
loss_fn = TripletLoss(margin=1.0)
anchor = torch.randn(32, 512)
positive = anchor + torch.randn(32, 512) * 0.1
negative = torch.randn(32, 512)
loss, metrics = loss_fn(anchor, positive, negative)
print(f"Triplet loss: {metrics['loss']:.4f}, accuracy: {metrics['accuracy']:.4f}")
Triplet loss: 0.0000, accuracy: 1.0000

16.2.2 Advanced Triplet Loss Variants

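At enterprise scale, most randomly sampled triplets already satisfy the margin and contribute zero loss, so basic triplet loss learns slowly. The variants below mine more informative triplets within each batch and optionally replace the hinge with a soft margin: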

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdvancedTripletLoss(nn.Module):
    """Advanced triplet loss with hard/semi-hard mining and soft margin"""

    def __init__(self, margin=1.0, mining_strategy="semi-hard", use_soft_margin=False, distance_metric="euclidean"):
        super().__init__()
        self.margin = margin
        self.mining_strategy = mining_strategy
        self.use_soft_margin = use_soft_margin
        self.distance_metric = distance_metric

    def forward(self, embeddings, labels):
        # Compute pairwise distances
        if self.distance_metric == "euclidean":
            distances = torch.cdist(embeddings, embeddings, p=2)
        else:
            embeddings_norm = F.normalize(embeddings, p=2, dim=1)
            distances = 1 - torch.mm(embeddings_norm, embeddings_norm.T)

        triplets = self._mine_triplets(distances, labels)
        if len(triplets) == 0:
            return torch.tensor(0.0, device=embeddings.device), {"loss": 0.0, "num_triplets": 0}

        anchor_idx, positive_idx, negative_idx = zip(*triplets)
        pos_distances = distances[anchor_idx, positive_idx]
        neg_distances = distances[anchor_idx, negative_idx]

        if self.use_soft_margin:
            # softplus(x) = log(1 + exp(x)), computed without overflow for large x
            loss = F.softplus(pos_distances - neg_distances).mean()
        else:
            loss = F.relu(pos_distances - neg_distances + self.margin).mean()

        with torch.no_grad():
            metrics = {"loss": loss.item(), "num_triplets": len(triplets),
                       "hard_triplets_fraction": (pos_distances > neg_distances).float().mean().item()}
        return loss, metrics

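    def _mine_triplets(self, distances, labels):
        """Mine (anchor, positive, negative) index triplets.

        hard:      closest negative to the anchor, for every positive
        semi-hard: closest negative that is farther than the positive but within the
                   margin (FaceNet-style); falls back to the hardest negative if none exists
        all:       every valid negative for every positive (O(B^3) triplets)
        """
        batch_size = labels.shape[0]
        arange = torch.arange(batch_size, device=labels.device)
        triplets = []
        with torch.no_grad():  # mining only selects indices; no gradients needed
            for i in range(batch_size):
                pos_indices = torch.where((labels == labels[i]) & (arange != i))[0]
                neg_indices = torch.where(labels != labels[i])[0]
                if len(pos_indices) == 0 or len(neg_indices) == 0:
                    continue
                neg_dists = distances[i, neg_indices]
                for pos_idx in pos_indices:
                    if self.mining_strategy == "all":
                        triplets.extend((i, pos_idx.item(), n.item()) for n in neg_indices)
                        continue
                    neg_idx = neg_indices[neg_dists.argmin()]  # hardest negative
                    if self.mining_strategy == "semi-hard":
                        d_ap = distances[i, pos_idx]
                        semi_hard = (neg_dists > d_ap) & (neg_dists < d_ap + self.margin)
                        if semi_hard.any():
                            candidates = neg_dists.masked_fill(~semi_hard, float("inf"))
                            neg_idx = neg_indices[candidates.argmin()]
                    triplets.append((i, pos_idx.item(), neg_idx.item()))
        return triplets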

# Usage example
loss_fn = AdvancedTripletLoss(margin=1.0, mining_strategy="semi-hard")
embeddings = torch.randn(50, 512)
labels = torch.randint(0, 10, (50,))
loss, metrics = loss_fn(embeddings, labels)
print(f"Advanced triplet loss: {metrics['loss']:.4f}, triplets: {metrics['num_triplets']}")
Advanced triplet loss: 1.0535, triplets: 288
Tip: Mining Strategy Selection

Hard negative mining: Best for well-separated classes. Can cause training instability if classes overlap.

Semi-hard negative mining: Recommended for production. Balances learning speed with stability. Use when classes have some overlap.

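All triplets: Only for small datasets (< 10K examples) or final fine-tuning. Enumerating every valid triplet in a batch of size B costs O(B³), and most of those triplets are easy and contribute zero loss.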

Rule of thumb: Start with semi-hard, switch to hard if training plateaus after 70% of epochs.
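The three regimes are defined by where the negative sits relative to the positive and the margin (the FaceNet-style definitions). The small helper below makes the categories concrete; its name and scalar-distance interface are illustrative only, not part of any library.

```python
def triplet_category(d_ap: float, d_an: float, margin: float = 1.0) -> str:
    """Classify a triplet from its anchor-positive and anchor-negative distances."""
    if d_an < d_ap:
        return "hard"        # negative is closer to the anchor than the positive
    if d_an < d_ap + margin:
        return "semi-hard"   # correct order, but still inside the margin (non-zero loss)
    return "easy"            # satisfies the margin, so it contributes zero loss

margin = 1.0
for d_ap, d_an in [(0.8, 0.5), (0.8, 1.2), (0.8, 2.5)]:
    loss = max(d_ap - d_an + margin, 0.0)
    print(f"d_ap={d_ap}, d_an={d_an}: {triplet_category(d_ap, d_an, margin)}, loss={loss:.1f}")
```

The three example triplets come out as hard (loss 1.3), semi-hard (loss 0.6), and easy (loss 0.0). Easy triplets are the ones mining strategies try to avoid wasting compute on.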

16.2.3 Batch Construction for Efficient Triplet Training

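Batch-level triplet mining only works if every batch contains several examples of each class it includes; otherwise anchors have no positives to pair with. A P × K sampler guarantees this by drawing P classes and K examples from each: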

import numpy as np
import torch
from torch.utils.data import Sampler

class BalancedBatchSampler(Sampler):
    """Sampler ensuring each batch has P classes × K samples per class"""

    def __init__(self, labels, n_classes_per_batch=10, n_samples_per_class=5):
        self.labels = np.array(labels)
        self.n_classes_per_batch = n_classes_per_batch
        self.n_samples_per_class = n_samples_per_class

        # Build index mapping: class_id -> [sample_indices]
        self.class_to_indices = {}
        for idx, label in enumerate(self.labels):
            if label not in self.class_to_indices:
                self.class_to_indices[label] = []
            self.class_to_indices[label].append(idx)

        # Keep classes with enough samples
        self.valid_classes = [c for c, indices in self.class_to_indices.items()
                              if len(indices) >= self.n_samples_per_class]
        self.batch_size = n_classes_per_batch * n_samples_per_class

    def __iter__(self):
        classes = np.random.permutation(self.valid_classes)
        # Only emit full batches of P classes, so iteration matches __len__
        for i in range(0, len(classes) - self.n_classes_per_batch + 1, self.n_classes_per_batch):
            batch_classes = classes[i : i + self.n_classes_per_batch]
            batch_indices = []
            for class_id in batch_classes:
                class_indices = self.class_to_indices[class_id]
                # valid_classes guarantees at least K samples, so sample without replacement
                sampled = np.random.choice(class_indices, size=self.n_samples_per_class, replace=False)
                batch_indices.extend(int(idx) for idx in sampled)
            yield batch_indices

    def __len__(self):
        return len(self.valid_classes) // self.n_classes_per_batch

# Usage example
labels = np.random.randint(0, 100, size=10000)  # 10K samples, 100 classes
sampler = BalancedBatchSampler(labels, n_classes_per_batch=10, n_samples_per_class=5)
print(f"Batch size: {sampler.batch_size}, Batches per epoch: {len(sampler)}")
Batch size: 50, Batches per epoch: 10
Warning: Production Batch Sizing

Memory constraints: P × K = batch_size. Larger batches provide more triplets but require more memory.

Recommended configurations:

  • Small models (< 100M params): P=16, K=8, batch_size=128
  • Medium models (100M-1B params): P=10, K=5, batch_size=50
  • Large models (> 1B params): P=8, K=4, batch_size=32

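GPU utilization: Use gradient accumulation to simulate larger batches if needed. Note that accumulation averages gradients across micro-batches but does not enlarge the candidate pool for in-batch triplet mining, because each micro-batch is mined separately.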
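The P × K layout also fixes how many training signals a batch can provide: each of the P·K anchors has K − 1 positives and (P − 1)·K negatives. The hypothetical helper below does that arithmetic for the three configurations listed above.

```python
def pk_batch_stats(p: int, k: int) -> dict:
    """Count anchors, anchor-positive pairs, and valid triplets in a P x K batch."""
    anchors = p * k
    positives_per_anchor = k - 1
    negatives_per_anchor = (p - 1) * k
    pairs = anchors * positives_per_anchor
    return {
        "batch_size": anchors,
        "anchor_positive_pairs": pairs,
        "all_triplets": pairs * negatives_per_anchor,
    }

for p, k in [(16, 8), (10, 5), (8, 4)]:
    print(f"P={p}, K={k}: {pk_batch_stats(p, k)}")
```

With hard or semi-hard mining, one negative is kept per anchor-positive pair, so a P=10, K=5 batch yields 200 triplets instead of the 9,000 that exhaustive enumeration would produce.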

16.3 One-Shot Learning for Rare Events

One-shot learning—learning from a single example—is critical for enterprise scenarios where rare events are important but examples are scarce: fraud detection, manufacturing defects, zero-day threats, rare diseases.

16.3.1 One-Shot Learning Fundamentals

A standard classifier cannot learn a new class from a single example, because it needs many labeled examples per class and a fixed label set. Siamese networks succeed by:

  1. Learning similarity during training on abundant data
  2. Applying similarity at inference to new classes with few examples
  3. Comparing rather than classifying new inputs
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class OneShotClassifier:
    """One-shot classifier: classify by finding most similar support example"""

    def __init__(self, siamese_model, distance_metric="euclidean"):
        self.model = siamese_model
        self.distance_metric = distance_metric
        self.support_set = {}  # class_id -> embedding

    def add_support_example(self, class_id, example):
        """Add a single example (shape (1, ...) or unbatched) for a new class"""
        with torch.no_grad():
            self.model.eval()
            embedding = self.model.get_embedding(example)
            self.support_set[class_id] = embedding.reshape(-1).cpu()  # store as a flat vector

    def _sorted_distances(self, query):
        """Return [(class_id, distance), ...] for a single query, nearest first"""
        with torch.no_grad():
            self.model.eval()
            query_embedding = self.model.get_embedding(query).reshape(1, -1)
            distances = {}
            for class_id, support_emb in self.support_set.items():
                support_emb = support_emb.to(query_embedding.device).unsqueeze(0)
                if self.distance_metric == "euclidean":
                    dist = F.pairwise_distance(query_embedding, support_emb).item()
                else:
                    dist = (1 - F.cosine_similarity(query_embedding, support_emb)).item()
                distances[class_id] = dist
        return sorted(distances.items(), key=lambda x: x[1])

    def predict(self, query, return_distances=False, top_k=1):
        """Predict class by finding nearest support example"""
        sorted_classes = self._sorted_distances(query)
        if top_k == 1:
            return (sorted_classes[0][0], sorted_classes[0][1]) if return_distances else sorted_classes[0][0]
        results = sorted_classes[:top_k]
        return ([c for c, _ in results], [d for _, d in results]) if return_distances else [c for c, _ in results]

    def predict_proba(self, query, temperature=1.0):
        """Predict class probabilities using softmax over negative distances"""
        sorted_classes = self._sorted_distances(query)
        class_ids = [c for c, _ in sorted_classes]
        logits = -np.array([d for _, d in sorted_classes]) / temperature
        exp_sims = np.exp(logits - logits.max())  # subtract max for numerical stability
        return dict(zip(class_ids, exp_sims / exp_sims.sum()))

# Usage example (placeholder model)
class PlaceholderModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(50, 128)
    def get_embedding(self, x):
        return F.normalize(self.encoder(x), dim=-1)

model = PlaceholderModel()
classifier = OneShotClassifier(model)
classifier.add_support_example("fraud_type_A", torch.randn(1, 50))
classifier.add_support_example("fraud_type_B", torch.randn(1, 50))
query = torch.randn(1, 50)
pred = classifier.predict(query)
print(f"Predicted class: {pred}")
Predicted class: fraud_type_B
Tip: When One-Shot Learning Works Best

Ideal scenarios:

  • High-quality training data (even if small)
  • Well-defined similarity metric
  • Rare event detection (fraud, anomalies, defects)
  • Rapidly evolving categories (new threats, trends)

Challenging scenarios:

  • Noisy data (single example may be unrepresentative)
  • Complex decision boundaries
  • Classes that require multiple features to distinguish

Best practice: Collect 3-5 examples per class when possible. Average their embeddings for more robust representation.

16.3.2 Few-Shot Learning Extensions

When you have 2-10 examples per class (few-shot), you can use more sophisticated techniques:

import torch
import torch.nn.functional as F

class PrototypicalNetworkClassifier:
    """Prototypical Networks: compute class prototypes from K examples, classify by nearest prototype"""

    def __init__(self, embedding_model):
        self.model = embedding_model
        self.prototypes = {}  # class_id -> prototype embedding

    def compute_prototypes(self, support_set):
        """Compute prototype (centroid) for each class from support examples"""
        self.prototypes = {}
        with torch.no_grad():
            self.model.eval()
            for class_id, examples in support_set.items():
                if isinstance(examples, list):
                    examples = torch.stack(examples)
                embeddings = self.model.get_embedding(examples)
                self.prototypes[class_id] = embeddings.mean(dim=0)

    def predict(self, query):
        """Classify query by finding nearest prototype"""
        with torch.no_grad():
            self.model.eval()
            query_embedding = self.model.get_embedding(query)
            distances = {class_id: F.pairwise_distance(query_embedding, proto.unsqueeze(0)).item()
                         for class_id, proto in self.prototypes.items()}
            return min(distances.items(), key=lambda x: x[1])[0]

# Usage example
import torch.nn as nn
class SimpleEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(64, 128)
    def get_embedding(self, x):
        return F.normalize(self.net(x), dim=-1)

encoder = SimpleEncoder()
classifier = PrototypicalNetworkClassifier(encoder)
support_set = {"class_A": torch.randn(5, 64), "class_B": torch.randn(5, 64)}  # 5 examples each
classifier.compute_prototypes(support_set)
query = torch.randn(1, 64)
pred = classifier.predict(query)
print(f"Predicted class: {pred}")
Predicted class: class_B
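The classifier above uses prototypes only at inference time. In the Prototypical Networks formulation (Snell et al., 2017), the encoder is also trained episodically: each episode samples N classes with K support and several query examples, builds prototypes from the support embeddings, and applies cross-entropy to a softmax over negative squared Euclidean distances from each query to each prototype. The sketch below shows one episode's loss; the function name, the tensor layout (classes along the first dimension), the unnormalized encoder, and the synthetic Gaussian data are all assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def prototypical_episode_loss(encoder, support, query):
    """support: (n_way, k_shot, feat_dim), query: (n_way, n_query, feat_dim) -> (loss, accuracy)"""
    n_way, k_shot, feat_dim = support.shape
    n_query = query.shape[1]

    support_emb = encoder(support.reshape(-1, feat_dim)).reshape(n_way, k_shot, -1)
    prototypes = support_emb.mean(dim=1)                      # (n_way, emb_dim)
    query_emb = encoder(query.reshape(-1, feat_dim))          # (n_way * n_query, emb_dim)

    # Negative squared Euclidean distance to each prototype acts as the logit
    sq_dists = ((query_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=-1)
    logits = -sq_dists
    # Query rows are grouped by class, so row r belongs to class r // n_query
    targets = torch.arange(n_way, device=logits.device).repeat_interleave(n_query)
    loss = F.cross_entropy(logits, targets)
    accuracy = (logits.argmax(dim=1) == targets).float().mean()
    return loss, accuracy

torch.manual_seed(0)
encoder = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

# Synthetic 5-way episodes: each class is a Gaussian blob around its own center
centers = torch.randn(5, 64) * 3
for step in range(100):
    support = centers[:, None, :] + torch.randn(5, 5, 64)   # 5-shot support set
    query = centers[:, None, :] + torch.randn(5, 10, 64)    # 10 queries per class
    loss, acc = prototypical_episode_loss(encoder, support, query)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f"final episode loss={loss.item():.3f}, accuracy={acc.item():.2f}")
```

At inference the same nearest-prototype rule applies, which is what `PrototypicalNetworkClassifier.predict` does above.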

16.4 Similarity Threshold Calibration

A critical but often overlooked challenge: how do you set the threshold for “similar enough”? With a distance threshold (pairs closer than the threshold count as matches), setting it too high admits false positives, and setting it too low misses true matches.

16.4.1 The Threshold Calibration Challenge

import numpy as np
import torch
import torch.nn.functional as F

class ThresholdCalibrator:
    """
    Calibrate similarity thresholds for production deployment

    Challenge: The optimal threshold depends on:
    - Distribution of true positives vs negatives
    - Business costs of false positives vs false negatives
    - Dataset characteristics (intra-class vs inter-class variance)

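    Label convention: validation_labels use 1 = similar, 0 = dissimilar (the opposite
    of the ContrastiveLoss example earlier in this chapter). A pair is predicted
    similar when its embedding distance is below the threshold.

    This class provides multiple calibration strategies.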
    """

    def __init__(self, siamese_model):
        self.model = siamese_model
        self.threshold = None
        self.calibration_metrics = {}

    def calibrate_on_validation_set(
        self,
        validation_pairs,
        validation_labels,
        metric='f1',
        plot=False
    ):
        """
        Calibrate threshold on validation set to optimize a metric

        Args:
            validation_pairs: List of (item1, item2) pairs
            validation_labels: 1 if similar, 0 if dissimilar
            metric: 'f1', 'precision', 'recall', or 'accuracy'
            plot: If True, plot threshold vs metric curve

        Returns:
            Optimal threshold value
        """
        # Compute distances for all pairs
        distances = []

        with torch.no_grad():
            self.model.eval()

            for item1, item2 in validation_pairs:
                embedding1 = self.model.get_embedding(item1.unsqueeze(0))
                embedding2 = self.model.get_embedding(item2.unsqueeze(0))

                distance = F.pairwise_distance(embedding1, embedding2).item()
                distances.append(distance)

        distances = np.array(distances)
        validation_labels = np.array(validation_labels)

        # Try different thresholds
        thresholds = np.linspace(distances.min(), distances.max(), 100)
        metrics_by_threshold = []

        for threshold in thresholds:
            # Predict: similar if distance < threshold
            predictions = (distances < threshold).astype(int)

            # Compute metrics
            tp = ((predictions == 1) & (validation_labels == 1)).sum()
            fp = ((predictions == 1) & (validation_labels == 0)).sum()
            tn = ((predictions == 0) & (validation_labels == 0)).sum()
            fn = ((predictions == 0) & (validation_labels == 1)).sum()

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            accuracy = (tp + tn) / len(validation_labels)

            metrics_by_threshold.append({
                'threshold': threshold,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'accuracy': accuracy
            })

        # Find threshold that maximizes chosen metric
        best_idx = max(
            range(len(metrics_by_threshold)),
            key=lambda i: metrics_by_threshold[i][metric]
        )

        self.threshold = metrics_by_threshold[best_idx]['threshold']
        self.calibration_metrics = metrics_by_threshold[best_idx]

        if plot:
            self._plot_calibration_curve(metrics_by_threshold, metric)

        return self.threshold

    def calibrate_with_business_costs(
        self,
        validation_pairs,
        validation_labels,
        false_positive_cost=1.0,
        false_negative_cost=1.0
    ):
        """
        Calibrate threshold based on business costs

        Args:
            validation_pairs: List of (item1, item2) pairs
            validation_labels: 1 if similar, 0 if dissimilar
            false_positive_cost: Cost of incorrectly marking as similar
            false_negative_cost: Cost of missing a true match

        Returns:
            Cost-optimal threshold

        Example costs:
        - Fraud detection: FN cost >> FP cost (missing fraud is expensive)
        - Product matching: FP cost >> FN cost (wrong matches annoy users)
        """
        # Compute distances
        distances = []

        with torch.no_grad():
            self.model.eval()

            for item1, item2 in validation_pairs:
                embedding1 = self.model.get_embedding(item1.unsqueeze(0))
                embedding2 = self.model.get_embedding(item2.unsqueeze(0))

                distance = F.pairwise_distance(embedding1, embedding2).item()
                distances.append(distance)

        distances = np.array(distances)
        validation_labels = np.array(validation_labels)

        # Try different thresholds
        thresholds = np.linspace(distances.min(), distances.max(), 100)
        costs = []

        for threshold in thresholds:
            predictions = (distances < threshold).astype(int)

            fp = ((predictions == 1) & (validation_labels == 0)).sum()
            fn = ((predictions == 0) & (validation_labels == 1)).sum()

            total_cost = fp * false_positive_cost + fn * false_negative_cost
            costs.append(total_cost)

        # Find threshold that minimizes cost
        best_idx = np.argmin(costs)
        self.threshold = thresholds[best_idx]

        self.calibration_metrics = {
            'threshold': self.threshold,
            'expected_cost': costs[best_idx],
            'false_positive_cost': false_positive_cost,
            'false_negative_cost': false_negative_cost
        }

        return self.threshold

    def calibrate_for_precision_target(
        self,
        validation_pairs,
        validation_labels,
        target_precision=0.95
    ):
        """
        Calibrate to achieve target precision

        Use when false positives are unacceptable (e.g., financial matching)

        Args:
            validation_pairs: List of (item1, item2) pairs
            validation_labels: 1 if similar, 0 if dissimilar
            target_precision: Desired precision (0-1)

        Returns:
            Threshold that achieves target precision (or closest possible)
        """
        # Compute distances
        distances = []

        with torch.no_grad():
            self.model.eval()

            for item1, item2 in validation_pairs:
                embedding1 = self.model.get_embedding(item1.unsqueeze(0))
                embedding2 = self.model.get_embedding(item2.unsqueeze(0))

                distance = F.pairwise_distance(embedding1, embedding2).item()
                distances.append(distance)

        distances = np.array(distances)
        validation_labels = np.array(validation_labels)

        # Sweep candidate thresholds; a pair is predicted similar when distance < threshold
        thresholds = np.linspace(distances.min(), distances.max(), 100)

        sweep = []
        for threshold in thresholds:
            predictions = (distances < threshold).astype(int)

            tp = ((predictions == 1) & (validation_labels == 1)).sum()
            fp = ((predictions == 1) & (validation_labels == 0)).sum()
            fn = ((predictions == 0) & (validation_labels == 1)).sum()

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            sweep.append((threshold, precision, recall))

        # Among thresholds that meet the precision target, keep the one with the highest recall
        feasible = [s for s in sweep if s[1] >= target_precision]
        if feasible:
            best_threshold, best_precision, best_recall = max(feasible, key=lambda s: s[2])
        else:
            # Target unreachable: fall back to the threshold with the highest precision
            # (its recall is reported too, instead of being left at 0)
            best_threshold, best_precision, best_recall = max(sweep, key=lambda s: s[1])

        self.threshold = best_threshold
        self.calibration_metrics = {
            'threshold': best_threshold,
            'achieved_precision': best_precision,
            'achieved_recall': best_recall,
            'target_precision': target_precision
        }

        return self.threshold

    def _plot_calibration_curve(self, metrics_by_threshold, target_metric):
        """Plot threshold vs metric curve"""
        import matplotlib.pyplot as plt

        thresholds = [m['threshold'] for m in metrics_by_threshold]
        values = [m[target_metric] for m in metrics_by_threshold]

        plt.figure(figsize=(10, 6))
        plt.plot(thresholds, values)
        plt.axvline(self.threshold, color='r', linestyle='--',
                   label=f'Optimal: {self.threshold:.3f}')
        plt.xlabel('Threshold')
        plt.ylabel(target_metric.capitalize())
        plt.title(f'Threshold Calibration: {target_metric.capitalize()}')
        plt.legend()
        plt.grid(True)
        plt.show()
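The precision-target sweep scores 100 evenly spaced thresholds, and that grid can miss the best cut when distances cluster tightly. The following is a minimal vectorized sketch, not part of the class. It assumes non-empty distances and binary labels. The function sorts the distances once and evaluates every cut between distinct observed values. `threshold_for_precision` is a hypothetical helper.

```python
import numpy as np

def threshold_for_precision(distances, labels, target_precision=0.95):
    """Return (threshold, precision, recall) with the highest recall among cuts whose
    precision >= target_precision, or None if no cut reaches the target.

    A pair is predicted similar when distance < threshold, the same convention
    as calibrate_for_precision_target.
    """
    distances = np.asarray(distances, dtype=float)
    labels = np.asarray(labels, dtype=int)
    order = np.argsort(distances)
    d_sorted, y_sorted = distances[order], labels[order]

    # Predicting the k closest pairs as similar gives these cumulative counts
    tp = np.cumsum(y_sorted)
    fp = np.cumsum(1 - y_sorted)
    precision = tp / (tp + fp)
    recall = tp / max(labels.sum(), 1)

    # Only cut between distinct distances so tied pairs are never split
    valid = np.append(d_sorted[1:] > d_sorted[:-1], True)
    # Place each threshold halfway to the next distance (just above the max for the last cut)
    next_d = np.append(d_sorted[1:], d_sorted[-1] + 1e-6)
    thresholds = (d_sorted + next_d) / 2

    ok = valid & (precision >= target_precision)
    if not ok.any():
        return None
    best = np.flatnonzero(ok)[np.argmax(recall[ok])]
    return float(thresholds[best]), float(precision[best]), float(recall[best])
```

Because precision and recall come from cumulative sums, the cost is dominated by the sort (O(n log n)). The grid sweep instead rescans all pairs once per candidate threshold.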

Warning: Threshold Calibration Best Practices

Re-calibrate regularly: Data distributions drift. Re-calibrate quarterly or when you detect performance degradation.

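Use stratified validation: Ensure your validation set reflects the production distribution, including the ratio of similar to dissimilar pairs. Precision depends on that ratio, so a threshold calibrated on differently balanced data will not deliver the same precision in production.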

Monitor threshold effectiveness: Track precision/recall in production. Alert if metrics deviate > 5% from calibration values.

Business cost alignment: Work with stakeholders to quantify FP and FN costs. Technical metrics (F1) may not align with business value.
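One way to monitor threshold effectiveness is to periodically compute precision and recall on labeled production pairs and compare them with the values stored at calibration time. The sketch below assumes the calibration dict uses the `achieved_precision` and `achieved_recall` keys that `calibrate_for_precision_target` stores. It also interprets "5%" as an absolute difference of 0.05. `check_threshold_health` is a hypothetical helper.

```python
import numpy as np

def check_threshold_health(distances, labels, threshold, calibration, tolerance=0.05):
    """Compare production precision/recall with calibration-time values.

    calibration: dict with 'achieved_precision' and 'achieved_recall'.
    tolerance: allowed absolute deviation (assumed interpretation of "5%").
    Returns a list of alert strings (empty when both metrics are within tolerance).
    """
    predictions = np.asarray(distances) < threshold  # similar when distance < threshold
    labels = np.asarray(labels).astype(bool)
    tp = np.sum(predictions & labels)
    fp = np.sum(predictions & ~labels)
    fn = np.sum(~predictions & labels)
    current = {
        "achieved_precision": tp / (tp + fp) if tp + fp else 0.0,
        "achieved_recall": tp / (tp + fn) if tp + fn else 0.0,
    }
    alerts = []
    for name, value in current.items():
        if abs(value - calibration[name]) > tolerance:
            alerts.append(f"{name}: {value:.3f} vs {calibration[name]:.3f} at calibration")
    return alerts
```

Labels in production often arrive late, for example through chargebacks or analyst review. In that case, run this check on a rolling window of pairs whose outcomes are already known.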

16.4.2 Dynamic Threshold Adaptation

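In production, a single static threshold tends to degrade: data distributions drift, and different categories (for example, high-value items) may need different operating points. The manager below keeps per-category overrides. When labeled feedback shows an error rate above 20%, it also moves the base threshold part of the way toward the error-minimizing value: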

import numpy as np

class AdaptiveThresholdManager:
    """Manage distance thresholds that adapt to changing data distributions.

    Convention: a pair is predicted similar when distance < threshold.
    """

    def __init__(self, base_threshold=0.5):
        self.base_threshold = base_threshold
        self.category_thresholds = {}
        self.performance_history = []

    def get_threshold(self, category=None, confidence=None):
        """Get threshold, using a category override if one exists and
        shifting it by up to ±0.1 for confidence values in [0, 1]."""
        threshold = self.base_threshold
        if category is not None and category in self.category_thresholds:
            threshold = self.category_thresholds[category]
        if confidence is not None:
            adjustment = (confidence - 0.5) * 0.2  # ±0.1 when confidence is in [0, 1]
            threshold = threshold - adjustment      # higher confidence lowers the distance threshold
        return threshold

    def update_category_threshold(self, category, new_threshold):
        self.category_thresholds[category] = new_threshold

    def adapt_from_feedback(self, distances, labels, learning_rate=0.1, max_error_rate=0.2):
        """Move the base threshold toward the error-minimizing value when the
        error rate on recent labeled pairs exceeds max_error_rate."""
        distances = np.asarray(distances, dtype=float)
        labels = np.asarray(labels, dtype=int)
        current_predictions = (distances < self.base_threshold).astype(int)
        error_rate = (current_predictions != labels).mean()

        if error_rate > max_error_rate:
            best_threshold = self._find_optimal_threshold(distances, labels)
            # Exponential smoothing avoids jumping on a single noisy batch
            self.base_threshold = (1 - learning_rate) * self.base_threshold + learning_rate * best_threshold

        self.performance_history.append({"threshold": self.base_threshold, "error_rate": error_rate})

    def _find_optimal_threshold(self, distances, labels):
        """Grid-search the threshold with the lowest error rate."""
        thresholds = np.linspace(distances.min(), distances.max(), 50)
        error_rates = [((distances < t).astype(int) != labels).mean() for t in thresholds]
        return thresholds[np.argmin(error_rates)]

# Usage example
manager = AdaptiveThresholdManager(base_threshold=0.5)
manager.update_category_threshold("high_value", 0.7)
print(f"Base threshold: {manager.base_threshold}")
print(f"High-value threshold: {manager.get_threshold(category='high_value')}")
Base threshold: 0.5
High-value threshold: 0.7
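The usage example sets the high-value threshold by hand. Per-category thresholds can instead be derived from labeled validation pairs. The sketch below assumes each pair carries a category tag. It minimizes error rate per category, mirroring `_find_optimal_threshold`, and skips categories with too little data so that `get_threshold` falls back to the base threshold for them. `calibrate_category_thresholds` is a hypothetical helper; its output could be passed to `update_category_threshold`.

```python
import numpy as np

def calibrate_category_thresholds(distances, labels, categories, min_pairs=50, num_candidates=50):
    """Fit one distance threshold per category by minimizing the error rate.

    Categories with fewer than min_pairs labeled pairs are left out of the
    result, so callers keep using the base threshold for them.
    """
    distances = np.asarray(distances, dtype=float)
    labels = np.asarray(labels, dtype=int)
    categories = np.asarray(categories)
    thresholds = {}
    for category in np.unique(categories):
        mask = categories == category
        if mask.sum() < min_pairs:
            continue  # too few pairs for a reliable per-category estimate
        d, y = distances[mask], labels[mask]
        candidates = np.linspace(d.min(), d.max(), num_candidates)
        error_rates = [np.mean((d < t).astype(int) != y) for t in candidates]
        # .item() converts the NumPy scalar key to a plain Python str/int
        thresholds[category.item()] = float(candidates[int(np.argmin(error_rates))])
    return thresholds
```

The `min_pairs` guard matters for rare categories. A threshold fitted on a handful of pairs can look perfect on validation data and still generalize poorly.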

16.5 Production Deployment Patterns

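Deploying Siamese networks at scale requires careful architecture design. The three patterns below address the main bottlenecks: recomputing embeddings for repeated items, searching large candidate sets, and balancing precision against human review cost.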

16.5.1 Pattern 1: Embedding Cache Architecture

import hashlib
import torch
import torch.nn.functional as F

class SiameseEmbeddingService:
    """Production service with embedding caching, batch processing, GPU/CPU flexibility.

    Embeddings are always returned on the CPU, whether they come from the cache or the model.
    """

    def __init__(self, model, cache_size=100000, batch_size=256, device="cuda"):
        self.model = model.to(device).eval()
        self.device = device
        self.batch_size = batch_size
        self.embedding_cache = {}
        self.cache_size = cache_size
        self.cache_hits = 0
        self.cache_misses = 0

    def _get_cache_key(self, item):
        # Hash the raw tensor bytes so identical inputs share a cache entry
        return hashlib.md5(item.cpu().numpy().tobytes()).hexdigest()

    def get_embedding(self, item, use_cache=True):
        if use_cache:
            cache_key = self._get_cache_key(item)
            if cache_key in self.embedding_cache:
                self.cache_hits += 1
                return self.embedding_cache[cache_key]
            self.cache_misses += 1

        with torch.no_grad():
            # Move to CPU so cached and freshly computed embeddings share a device
            embedding = self.model.get_embedding(item.to(self.device)).cpu()

        if use_cache:
            if len(self.embedding_cache) >= self.cache_size:
                # FIFO eviction: dicts keep insertion order, so the first key is the oldest entry
                oldest_key = next(iter(self.embedding_cache))
                del self.embedding_cache[oldest_key]
            self.embedding_cache[cache_key] = embedding
        return embedding

    def get_embeddings_batch(self, items):
        """Embed a stacked tensor of items in chunks of batch_size (bypasses the cache)."""
        embeddings = []
        for i in range(0, len(items), self.batch_size):
            batch = items[i : i + self.batch_size]
            with torch.no_grad():
                batch_embeddings = self.model.get_embedding(batch.to(self.device))
            embeddings.append(batch_embeddings.cpu())
        return torch.cat(embeddings, dim=0)

    def compare(self, item1, item2):
        """Cosine similarity between two single items, each shaped (1, ...)."""
        emb1, emb2 = self.get_embedding(item1), self.get_embedding(item2)
        # dim=-1 compares along the feature axis; dim=0 would compare across the batch axis
        return F.cosine_similarity(emb1, emb2, dim=-1).item()

    def get_cache_stats(self):
        total = self.cache_hits + self.cache_misses
        return {"cache_size": len(self.embedding_cache), "cache_hits": self.cache_hits,
                "cache_misses": self.cache_misses, "hit_rate": self.cache_hits / total if total > 0 else 0}

# Usage example (with placeholder model)
import torch.nn as nn
class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(64, 128)
    def get_embedding(self, x):
        return F.normalize(self.net(x), dim=-1)

model = DummyModel()
service = SiameseEmbeddingService(model, cache_size=1000, device="cpu")
item = torch.randn(1, 64)
emb = service.get_embedding(item)
print(f"Embedding shape: {emb.shape}, Cache stats: {service.get_cache_stats()}")
Embedding shape: torch.Size([1, 128]), Cache stats: {'cache_size': 1, 'cache_hits': 0, 'cache_misses': 1, 'hit_rate': 0.0}
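`SiameseEmbeddingService` evicts the oldest inserted key (FIFO). As a result, a frequently queried item can be evicted simply because it was cached early. When access patterns are skewed toward a set of hot items, a least-recently-used (LRU) policy usually keeps those items resident longer. The following is a minimal sketch built on `collections.OrderedDict`. `LRUEmbeddingCache` is a hypothetical class; its `get`/`put` methods could replace the dict lookups in `get_embedding`.

```python
from collections import OrderedDict

class LRUEmbeddingCache:
    """Least-recently-used cache: hits move an entry to the end, evictions pop the front."""

    def __init__(self, capacity):
        self.capacity = capacity
        self._store = OrderedDict()

    def get(self, key):
        """Return the cached value, or None on a miss."""
        if key not in self._store:
            return None
        self._store.move_to_end(key)  # mark as most recently used
        return self._store[key]

    def put(self, key, value):
        if key in self._store:
            self._store.move_to_end(key)
        self._store[key] = value
        if len(self._store) > self.capacity:
            self._store.popitem(last=False)  # evict the least recently used entry

    def __len__(self):
        return len(self._store)
```

For multi-process serving, an in-process cache like this is duplicated per worker. A shared cache keyed by the same input hash avoids recomputing the same embeddings in each process.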

16.5.2 Pattern 2: Approximate Nearest Neighbor Integration

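For large-scale similarity search, store embeddings in a FAISS index. The service below uses IndexFlatIP, which performs exact (brute-force) search: every query is scored against every stored vector, so latency grows linearly with collection size. For billion-scale collections, swap in an approximate index such as IndexIVFFlat or IndexHNSWFlat, which trade a small amount of recall for much faster search: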

import torch.nn.functional as F

class SiameseANNService:
    """Siamese network integrated with a FAISS index for similarity search.

    IndexFlatIP is exact; replace it with an approximate index (e.g. IVF or HNSW)
    for very large collections.
    """

    def __init__(self, siamese_service, embedding_dim=512):
        # embedding_dim must match the output size of the model's get_embedding
        self.siamese_service = siamese_service
        self.embedding_dim = embedding_dim
        try:
            import faiss
            self.index = faiss.IndexFlatIP(embedding_dim)  # inner product on unit vectors = cosine
        except ImportError:
            print("FAISS not installed. Install with: pip install faiss-cpu")
            self.index = None
        self.id_to_index = {}
        self.index_to_id = {}

    def add_items(self, item_ids, items):
        if self.index is None:
            raise RuntimeError("FAISS not available")
        embeddings = self.siamese_service.get_embeddings_batch(items)
        # FAISS expects contiguous float32 arrays
        embeddings = F.normalize(embeddings, p=2, dim=1).cpu().numpy().astype("float32")
        start_idx = self.index.ntotal
        self.index.add(embeddings)
        for i, item_id in enumerate(item_ids):
            idx = start_idx + i
            self.id_to_index[item_id] = idx
            self.index_to_id[idx] = item_id

    def search(self, query, top_k=10):
        if self.index is None:
            raise RuntimeError("FAISS not available")
        query_embedding = self.siamese_service.get_embedding(query).reshape(1, -1)
        query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu().numpy().astype("float32")
        similarities, indices = self.index.search(query_embedding, top_k)
        # FAISS pads missing results with index -1, which the id lookup filters out
        return [(self.index_to_id[int(idx)], float(sim))
                for sim, idx in zip(similarities[0], indices[0]) if int(idx) in self.index_to_id]

    def get_statistics(self):
        return {"total_items": self.index.ntotal if self.index is not None else 0,
                "embedding_dim": self.embedding_dim}

# Usage note: requires FAISS and a SiameseEmbeddingService wrapping a trained model
print("SiameseANNService: FAISS-backed cosine similarity search")
print("Features: exact IndexFlatIP by default, cosine similarity via normalized inner product")
SiameseANNService: FAISS-backed cosine similarity search
Features: exact IndexFlatIP by default, cosine similarity via normalized inner product
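IndexFlatIP compares each query against every stored vector. Because the embeddings are L2-normalized before indexing, the inner product equals cosine similarity. The NumPy sketch below performs the same exact computation without FAISS. It is handy for unit-testing FAISS results on small data or for debugging a mismatch between scores. `cosine_top_k` is a hypothetical helper.

```python
import numpy as np

def cosine_top_k(query, corpus, top_k=3):
    """Exact top-k cosine search: L2-normalize, then rank by inner product."""
    q = query / np.linalg.norm(query)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    scores = c @ q                              # inner product of unit vectors = cosine similarity
    k = min(top_k, len(scores))
    top = np.argpartition(-scores, k - 1)[:k]   # unordered top-k in linear time
    top = top[np.argsort(-scores[top])]         # sort only the k winners by score
    return [(int(i), float(scores[i])) for i in top]
```

The scan touches all n vectors for every query. An approximate index avoids that full scan, which is why it becomes necessary at billion scale.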

16.5.3 Pattern 3: Multi-Stage Verification Pipeline

For high-precision applications (fraud, compliance), use multi-stage verification:

class MultiStageVerificationPipeline:
    """
    Multi-stage verification using Siamese networks

    Stage 1: Fast filtering with loose threshold
    Stage 2: Detailed verification with strict threshold
    Stage 3: Human review for borderline cases

    Reduces compute cost while maintaining high accuracy.
    """

    def __init__(
        self,
        siamese_service,
        stage1_threshold=0.7,  # Recall-optimized cosine similarity cutoff
        stage2_threshold=0.9,  # Precision-optimized cosine similarity cutoff
        use_ann=True,
        embedding_dim=512      # Must match the model's embedding size
    ):
        self.siamese_service = siamese_service
        self.stage1_threshold = stage1_threshold
        self.stage2_threshold = stage2_threshold

        if use_ann:
            # Items must be added with self.ann_service.add_items before ANN search
            self.ann_service = SiameseANNService(
                siamese_service,
                embedding_dim=embedding_dim
            )
        else:
            self.ann_service = None

        self.total_queries = 0
        self.stage1_candidates = 0
        self.stage2_matches = 0
        self.human_review_cases = 0

    def verify(self, query, candidate_pool=None, candidate_ids=None):
        """
        Multi-stage verification

        Args:
            query: Item to verify, shaped (1, ...)
            candidate_pool: Stacked tensor of candidates to check against
                          (or None to use ANN search)
            candidate_ids: IDs for candidates (if using candidate_pool);
                           defaults to positional indices

        Returns:
            Dict with:
            - matched: Boolean or 'needs_review'
            - match_id: ID of matched item (if any)
            - confidence: Similarity score
            - stage: Which stage made the decision
        """
        self.total_queries += 1

        # Stage 1: Fast filtering
        if self.ann_service is not None and candidate_pool is None:
            # Use ANN search for fast filtering
            stage1_results = self.ann_service.search(query, top_k=100)
            stage1_candidates = [
                (item_id, sim) for item_id, sim in stage1_results
                if sim >= self.stage1_threshold
            ]
        else:
            # Linear search through candidate pool
            if candidate_pool is None:
                raise ValueError("Must provide candidate_pool or use ANN")
            if candidate_ids is None:
                candidate_ids = list(range(len(candidate_pool)))

            # Query (1, D) broadcasts against candidates (N, D), giving N similarities
            query_embedding = self.siamese_service.get_embedding(query).reshape(1, -1).cpu()
            candidate_embeddings = self.siamese_service.get_embeddings_batch(
                candidate_pool
            ).cpu()

            similarities = F.cosine_similarity(
                query_embedding,
                candidate_embeddings,
                dim=1
            )

            stage1_candidates = [
                (candidate_ids[i], sim.item())
                for i, sim in enumerate(similarities)
                if sim.item() >= self.stage1_threshold
            ]

        self.stage1_candidates += len(stage1_candidates)

        if len(stage1_candidates) == 0:
            return {
                'matched': False,
                'match_id': None,
                'confidence': 0.0,
                'stage': 1
            }

        # Stage 2: Detailed verification
        # Here it applies the stricter threshold to the best stage 1 score.
        # In production, this might involve:
        # - More expensive model
        # - Feature-level comparison
        # - Additional business logic

        best_match = max(stage1_candidates, key=lambda x: x[1])
        match_id, similarity = best_match

        if similarity >= self.stage2_threshold:
            # High confidence match
            self.stage2_matches += 1
            return {
                'matched': True,
                'match_id': match_id,
                'confidence': similarity,
                'stage': 2
            }
        else:
            # Borderline case - needs human review
            self.human_review_cases += 1
            return {
                'matched': 'needs_review',
                'match_id': match_id,
                'confidence': similarity,
                'stage': 2,
                'review_reason': 'confidence_below_threshold'
            }

    def get_statistics(self):
        """Get pipeline statistics; human_review_rate is per verified query"""
        return {
            'total_queries': self.total_queries,
            'stage1_candidates': self.stage1_candidates,
            'stage2_matches': self.stage2_matches,
            'human_review_cases': self.human_review_cases,
            'human_review_rate': self.human_review_cases / max(self.total_queries, 1)
        }
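The pipeline's default cutoffs (0.7 for stage 1, 0.9 for stage 2) are placeholders. One way to set them from labeled validation pairs is to place stage 1 as high as possible while still keeping a target share of true matches, and to raise stage 2 only as far as needed to reach a precision target. Pairs scoring between the two cutoffs would then go to human review. The sketch below assumes cosine similarities where higher means more similar and at least one positive pair. `choose_stage_thresholds` is a hypothetical helper.

```python
import numpy as np

def choose_stage_thresholds(similarities, labels, target_recall=0.99, target_precision=0.95):
    """Pick (stage1, stage2) cutoffs from labeled validation pairs.

    Stage 1 keeps pairs with sim >= stage1 and retains at least target_recall of true matches.
    Stage 2 accepts pairs with sim >= stage2; it is the lowest observed score at which
    precision reaches target_precision, or None if that never happens.
    """
    sims = np.asarray(similarities, dtype=float)
    labels = np.asarray(labels, dtype=int)

    # Stage 1: drop at most (1 - target_recall) of the positive pairs
    pos = np.sort(sims[labels == 1])
    n_drop = int(np.floor((1 - target_recall) * len(pos) + 1e-9))  # epsilon guards float round-off
    stage1 = float(pos[n_drop])

    # Stage 2: precision when accepting the k highest-scoring pairs
    order = np.argsort(-sims)
    s_sorted, y_sorted = sims[order], labels[order]
    precision = np.cumsum(y_sorted) / np.arange(1, len(sims) + 1)
    # Valid cuts fall between distinct scores so ties are never split
    valid = np.append(s_sorted[1:] < s_sorted[:-1], True)
    ok = np.flatnonzero(valid & (precision >= target_precision))
    stage2 = float(s_sorted[ok.max()]) if ok.size else None
    return stage1, stage2
```

The gap between the two cutoffs sets the human review load. Narrowing it reduces review volume at the cost of either lost recall in stage 1 or lower precision in stage 2.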
Tip: Production Deployment Checklist

Before deploying Siamese networks to production:

Ongoing maintenance:

  • Re-calibrate thresholds quarterly
  • Retrain on recent data every 3-6 months
  • Monitor for data drift (distributional shifts)
  • Collect hard negative examples for continuous improvement
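For drift monitoring, a minimal approach compares the mean similarity of recent production comparisons with a baseline captured at calibration time. It flags drift when the relative change exceeds 10%, the figure suggested in the key takeaways. A mean shift is a coarse signal, and distribution-level tests can catch shape changes that leave the mean unchanged. `SimilarityDriftMonitor` is a hypothetical class.

```python
import numpy as np

class SimilarityDriftMonitor:
    """Flag drift when the mean similarity of a recent window moves more than
    max_relative_change away from a baseline mean."""

    def __init__(self, baseline_similarities, max_relative_change=0.10):
        self.baseline_mean = float(np.mean(baseline_similarities))
        self.max_relative_change = max_relative_change

    def check(self, recent_similarities):
        recent_mean = float(np.mean(recent_similarities))
        # Guard against a zero baseline mean
        denominator = max(abs(self.baseline_mean), 1e-12)
        relative_change = abs(recent_mean - self.baseline_mean) / denominator
        return {
            "baseline_mean": self.baseline_mean,
            "recent_mean": recent_mean,
            "relative_change": relative_change,
            "drift": relative_change > self.max_relative_change,
        }
```

A drift flag is a cue to re-calibrate the threshold and inspect recent hard negatives before retraining, rather than proof that the model has degraded.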

16.6 Key Takeaways

  • Siamese networks learn similarity rather than classification, enabling few-shot learning, verification tasks, and open-set recognition without retraining.

  • Triplet loss with negative mining often provides more informative gradients than pairwise contrastive loss. Prefer semi-hard negatives for stable training, because selecting only the hardest negatives can make the model collapse early in training.

  • One-shot learning enables immediate adaptation to new categories from single examples—critical for fraud detection, rare defects, and rapidly evolving threats.

  • Threshold calibration is not optional. Use validation data to calibrate thresholds based on business metrics (precision/recall) or costs (FP/FN costs). Re-calibrate quarterly.

  • Production deployment relies on embedding caching and approximate nearest neighbor (ANN) indexes to keep similarity search fast at billion scale. Multi-stage pipelines balance cost and accuracy.

  • Monitor similarity distributions in production. Shifts indicate data drift or model degradation. Alert when mean similarity changes > 10% from baseline.

16.7 Looking Ahead

In Chapter 17, we expand beyond supervised and Siamese approaches to self-supervised learning—techniques that leverage the structure of unlabeled data itself to train powerful embeddings. We’ll explore masked language modeling, vision transformers, and multi-modal self-supervision strategies that enable learning from trillions of unlabeled examples across text, images, time-series, and more.

16.8 Further Reading

  • Bromley, J., et al. (1993). “Signature Verification using a Siamese Time Delay Neural Network.” NIPS.
  • Schroff, F., Kalenichenko, D., & Philbin, J. (2015). “FaceNet: A Unified Embedding for Face Recognition and Clustering.” CVPR.
  • Snell, J., Swersky, K., & Zemel, R. (2017). “Prototypical Networks for Few-shot Learning.” NeurIPS.
  • Koch, G., Zemel, R., & Salakhutdinov, R. (2015). “Siamese Neural Networks for One-shot Image Recognition.” ICML Workshop.
  • Wang, J., et al. (2017). “Deep Metric Learning with Angular Loss.” ICCV.
  • Hermans, A., Beyer, L., & Leibe, B. (2017). “In Defense of the Triplet Loss for Person Re-Identification.” arXiv:1703.07737.