from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class EcologyConfig:
image_size: int = 224
n_mels: int = 128
sequence_length: int = 256 # DNA barcode length
embedding_dim: int = 256
n_species: int = 10000
class SpeciesImageEncoder(nn.Module):
"""Encode species images for identification (camera traps, citizen science)."""
def __init__(self, config: EcologyConfig):
super().__init__()
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(3, 2, 1),
nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
self.proj = nn.Linear(256, config.embedding_dim)
self.species_head = nn.Linear(config.embedding_dim, config.n_species)
def forward(self, images: torch.Tensor) -> tuple:
features = self.backbone(images).squeeze(-1).squeeze(-1)
embeddings = F.normalize(self.proj(features), dim=-1)
return embeddings, self.species_head(embeddings)
class BioacousticEncoder(nn.Module):
"""Encode audio spectrograms for species identification (bird songs, whale calls)."""
def __init__(self, config: EcologyConfig):
super().__init__()
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.AdaptiveAvgPool2d(4))
self.proj = nn.Sequential(nn.Linear(128 * 16, 512), nn.ReLU(), nn.Linear(512, config.embedding_dim))
def forward(self, spectrograms: torch.Tensor) -> torch.Tensor:
features = self.encoder(spectrograms.unsqueeze(1)).flatten(1)
return F.normalize(self.proj(features), dim=-1)
class DNABarcodeEncoder(nn.Module):
"""Encode DNA barcode sequences for species identification (eDNA, metabarcoding)."""
def __init__(self, config: EcologyConfig):
super().__init__()
self.nucleotide_embed = nn.Embedding(5, 64) # A, C, G, T, N
self.conv = nn.Sequential(
nn.Conv1d(64, 128, 7, padding=3), nn.BatchNorm1d(128), nn.ReLU(), nn.MaxPool1d(2),
nn.Conv1d(128, 256, 5, padding=2), nn.BatchNorm1d(256), nn.ReLU(), nn.AdaptiveAvgPool1d(16))
self.proj = nn.Linear(256 * 16, config.embedding_dim)
def forward(self, sequences: torch.Tensor) -> torch.Tensor:
x = self.nucleotide_embed(sequences).transpose(1, 2)
x = self.conv(x).flatten(1)
return F.normalize(self.proj(x), dim=-1)
# Usage example
eco_config = EcologyConfig()
species_encoder = SpeciesImageEncoder(eco_config)
audio_encoder = BioacousticEncoder(eco_config)
dna_encoder = DNABarcodeEncoder(eco_config)
# Encode camera trap images
wildlife_images = torch.randn(4, 3, 224, 224)
species_emb, species_logits = species_encoder(wildlife_images)
print(f"Species embeddings: {species_emb.shape}") # [4, 256]
# Encode bird song spectrograms
spectrograms = torch.randn(4, 128, 200) # 128 mel bins, 200 time frames
audio_emb = audio_encoder(spectrograms)
print(f"Audio embeddings: {audio_emb.shape}") # [4, 256]
# Encode DNA barcodes
dna_seqs = torch.randint(0, 5, (4, 256))  # integer-encoded COI barcodes (0=A, 1=C, 2=G, 3=T, 4=N)
dna_emb = dna_encoder(dna_seqs)
print(f"DNA embeddings: {dna_emb.shape}") # [4, 256]