class VideoAutoencoder(nn.Module):
    """Learn to reconstruct normal video clips; at inference time a high
    reconstruction error flags anomalous content.

    Input and output layout is (batch, channels=3, time, height, width).
    Reconstructions pass through a final Sigmoid, so values lie in [0, 1].
    """

    def __init__(self, embedding_dim: int = 256):
        super().__init__()
        # Downsampling tower 3 -> 64 -> 128 -> 256 channels, then global
        # average pooling collapses the clip to one 256-d vector.
        self.encoder = nn.Sequential(
            nn.Conv3d(3, 64, (3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Conv3d(64, 128, (3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.Conv3d(128, 256, (3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(1),
        )
        self.latent_proj = nn.Linear(256, embedding_dim)
        # Seed volume for the decoder: 256 channels on a fixed 2x4x4 grid.
        self.decoder_proj = nn.Linear(embedding_dim, 256 * 2 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(256, 128, (3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, (3, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 3, (3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.Sigmoid(),
        )

    def encode(self, video: torch.Tensor) -> torch.Tensor:
        """Map a batch of clips to latent vectors of shape (batch, embedding_dim)."""
        pooled = self.encoder(video)
        return self.latent_proj(pooled.flatten(1))

    def decode(self, latent: torch.Tensor, target_shape: tuple) -> torch.Tensor:
        """Reconstruct a clip from ``latent``, resized to ``target_shape[2:]`` (T, H, W)."""
        seed = self.decoder_proj(latent).view(-1, 256, 2, 4, 4)
        upsampled = self.decoder(seed)
        # Trilinear resize so the output exactly matches the requested geometry.
        return F.interpolate(upsampled, size=target_shape[2:], mode='trilinear')

    def forward(self, video: torch.Tensor) -> tuple:
        """Return (reconstruction, latent) for a (B, 3, T, H, W) clip batch."""
        latent = self.encode(video)
        return self.decode(latent, video.shape), latent
class FramePredictionModel(nn.Module):
    """Predict future frames - anomalies are unpredictable.

    Expects input of shape (batch, seq_len, 3, H, W) and returns a tuple
    (predicted_frame, embedding): the predicted next frame of shape
    (batch, 3, H, W) and the final LSTM hidden state of shape
    (batch, embedding_dim).
    """

    def __init__(self, embedding_dim: int = 256):
        super().__init__()
        # Per-frame CNN encoder; the adaptive pool to a fixed 4x4 grid makes
        # the LSTM input size (256 * 16) independent of frame resolution.
        self.frame_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(4))
        self.temporal = nn.LSTM(256 * 16, embedding_dim, batch_first=True)
        # Each ConvTranspose2d (kernel 4, stride 2, padding 1) exactly doubles
        # the spatial size, so the decoder upsamples its input grid by 8x.
        self.frame_decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1), nn.Sigmoid())

    def forward(self, frame_sequence: torch.Tensor) -> tuple:
        """Encode each frame, roll through the LSTM, decode a next-frame guess.

        Args:
            frame_sequence: (batch, seq_len, 3, H, W) tensor of frames.

        Returns:
            (predicted_frame, embedding) with predicted_frame of shape
            (batch, 3, H, W) and embedding of shape (batch, embedding_dim).
        """
        batch, seq_len = frame_sequence.shape[:2]
        height, width = frame_sequence.shape[-2:]
        # Fold time into the batch dim so the 2D encoder sees every frame.
        frames_flat = frame_sequence.flatten(0, 1)
        frame_feats = self.frame_encoder(frames_flat).flatten(1).view(batch, seq_len, -1)
        lstm_out, _ = self.temporal(frame_feats)
        # Broadcast the last hidden state over a grid 1/8 the input size: the
        # decoder's 8x upsampling then returns it to the input resolution.
        # (Bug fix: the previous hard-coded 28x28 grid always decoded to
        # 224x224, so the prediction never matched the actual frame size.)
        pred_feats = lstm_out[:, -1].view(-1, lstm_out.shape[-1], 1, 1)
        grid = (max(1, height // 8), max(1, width // 8))
        pred_feats = F.interpolate(pred_feats, size=grid)
        predicted = self.frame_decoder(pred_feats)
        if predicted.shape[-2:] != (height, width):
            # Input not divisible by 8 - snap the output back to (H, W).
            predicted = F.interpolate(predicted, size=(height, width),
                                      mode='bilinear', align_corners=False)
        return predicted, lstm_out[:, -1]
# Usage example: score clips by how poorly the autoencoder reconstructs them.
autoencoder = VideoAutoencoder(embedding_dim=256)

video_clip = torch.randn(4, 3, 16, 112, 112)  # [batch, C, T, H, W]
reconstructed, latent = autoencoder(video_clip)

# Per-clip anomaly score: average the element-wise squared error over all
# non-batch dimensions, leaving one scalar per clip.
per_element = F.mse_loss(reconstructed, video_clip, reduction='none')
recon_error = per_element.mean(dim=[1, 2, 3, 4])
print(f"Reconstruction errors: {recon_error}")
print(f"Latent embeddings: {latent.shape}")  # [4, 256]