import torch
import torch.nn as nn
import torch.nn.functional as F
class MoCoTextEmbedding(nn.Module):
    """MoCo for text embeddings - works with small batches!"""

    def __init__(self, vocab_size=10000, embed_dim=256, projection_dim=128,
                 queue_size=4096, momentum=0.999):
        super().__init__()
        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = 0.07  # InfoNCE temperature

        # Query encoder (toy encoder: embedding + flatten + linear;
        # assumes a fixed sequence length of 20 tokens)
        self.encoder_q = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.Flatten(1), nn.Linear(embed_dim * 20, projection_dim)
        )
        # Key encoder (momentum updated, never trained by backprop)
        self.encoder_k = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.Flatten(1), nn.Linear(embed_dim * 20, projection_dim)
        )
        # Initialize the key encoder as a copy of the query encoder and freeze it
        for p_q, p_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            p_k.data.copy_(p_q.data)
            p_k.requires_grad = False

        # Queue of past key embeddings (one column per key) plus a circular write pointer
        self.register_buffer("queue", F.normalize(torch.randn(projection_dim, queue_size), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
    @torch.no_grad()
    def _momentum_update(self):
        # EMA update: the key encoder slowly tracks the query encoder
        for p_q, p_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            p_k.data = p_k.data * self.momentum + p_q.data * (1 - self.momentum)

    @torch.no_grad()
    def _update_queue(self, keys):
        batch_size = keys.shape[0]
        assert self.queue_size % batch_size == 0  # keep enqueue writes from wrapping mid-batch
        ptr = int(self.queue_ptr)
        self.queue[:, ptr:ptr + batch_size] = keys.T  # enqueue the newest keys
        self.queue_ptr[0] = (ptr + batch_size) % self.queue_size  # advance the circular pointer
    def forward(self, query_ids, key_ids):
        q = F.normalize(self.encoder_q(query_ids), dim=1)

        with torch.no_grad():
            self._momentum_update()
            k = F.normalize(self.encoder_k(key_ids), dim=1)

        # Positive logits: each query against its own key; negatives: queries against the queue
        l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)
        l_neg = torch.einsum("nc,ck->nk", q, self.queue.clone().detach())
        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=q.device)  # positive is at index 0
        loss = F.cross_entropy(logits, labels)

        self._update_queue(k)

        with torch.no_grad():
            accuracy = (logits.argmax(dim=1) == labels).float().mean()
        return loss, {"accuracy": accuracy.item(), "queue_ptr": int(self.queue_ptr)}
# Example: MoCo works with small batches!
torch.manual_seed(42)
model = MoCoTextEmbedding(vocab_size=1000, embed_dim=64, projection_dim=32, queue_size=256)

for i in range(5):
    query = torch.randint(0, 1000, (16, 20))  # batch of 16 sequences, 20 tokens each
    key = torch.randint(0, 1000, (16, 20))
    loss, metrics = model(query, key)
    print(f"MoCo Loss: {loss.item():.4f}")
    print(f"Accuracy: {metrics['accuracy']:.2%}")
    print(f"Queue filled: {metrics['queue_ptr']}/256")