Source code for tabula_rasa.models.encoders
"""Neural encoders for statistical sketches and text."""
import torch
import torch.nn as nn
class StatisticalEncoder(nn.Module):
    """Encode a statistical sketch into a neural representation.

    Handles variable-length column lists via attention pooling.
    """
    def __init__(self, hidden_dim: int = 256, output_dim: int = 768):
        """Initialize the statistical encoder.

        Args:
            hidden_dim: Hidden dimension for the column encoders.
            output_dim: Output dimension of the encoded sketch.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Column-level encoders
        self.numeric_encoder = nn.Sequential(
            nn.Linear(15, hidden_dim),  # 15 numeric features per column
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim),
        )

        # Attention pooling over columns
        self.attention = nn.MultiheadAttention(output_dim, num_heads=8, batch_first=True)

        # Global table encoder
        self.table_encoder = nn.Sequential(
            nn.Linear(output_dim, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(output_dim, output_dim),
        )
    def forward(self, sketch: dict) -> torch.Tensor:
        """Encode a sketch into a fixed-size vector.

        Args:
            sketch: Statistical sketch dictionary.

        Returns:
            Tensor of shape (output_dim,) representing the encoded table.
        """
        column_embeddings = []

        # Encode each numeric column
        for _col_name, col_stats in sketch["columns"].items():
            if col_stats["type"] == "numeric" and "error" not in col_stats:
                # Pack statistics into a 15-dimensional feature vector
                features = torch.tensor(
                    [
                        col_stats["mean"],
                        col_stats["std"],
                        col_stats["min"],
                        col_stats["max"],
                        col_stats["quantiles"][0.25],
                        col_stats["quantiles"][0.5],
                        col_stats["quantiles"][0.75],
                        col_stats["skewness"],
                        col_stats["kurtosis"],
                        col_stats["missing_rate"],
                        col_stats["outlier_rate"],
                        col_stats["n_unique"] / max(sketch["n_rows"], 1),  # Normalized cardinality
                        1.0 if col_stats["distribution_hint"] == "normal" else 0.0,
                        1.0 if col_stats["distribution_hint"] == "right_skewed" else 0.0,
                        1.0 if col_stats["distribution_hint"] == "heavy_tailed" else 0.0,
                    ],
                    dtype=torch.float32,
                )
                embedding = self.numeric_encoder(features)
                column_embeddings.append(embedding)

        if not column_embeddings:
            # No numeric columns - return a zero vector
            return torch.zeros(self.output_dim)

        # Stack column embeddings
        column_stack = torch.stack(column_embeddings).unsqueeze(0)  # (1, n_cols, dim)

        # Attention pooling: self-attention over columns, then mean-pool
        attended, _ = self.attention(column_stack, column_stack, column_stack)
        pooled = attended.mean(dim=1).squeeze(0)  # (dim,)

        # Final encoding
        return self.table_encoder(pooled)
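

# Minimal usage sketch (illustrative only; not part of the module). The dict
# below contains exactly the keys that StatisticalEncoder.forward() reads;
# the real sketch builder in tabula_rasa may emit additional fields, and all
# values here are made-up examples.
if __name__ == "__main__":
    example_sketch = {
        "n_rows": 1000,
        "columns": {
            "age": {
                "type": "numeric",
                "mean": 42.0,
                "std": 12.5,
                "min": 18.0,
                "max": 90.0,
                "quantiles": {0.25: 33.0, 0.5: 41.0, 0.75: 51.0},
                "skewness": 0.3,
                "kurtosis": 2.9,
                "missing_rate": 0.02,
                "outlier_rate": 0.01,
                "n_unique": 70,
                "distribution_hint": "normal",
            },
        },
    }
    encoder = StatisticalEncoder()
    encoder.eval()  # disable dropout for a deterministic forward pass
    with torch.no_grad():
        vector = encoder(example_sketch)
    print(vector.shape)  # torch.Size([768])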