# Source code for tabula_rasa.models.qa_model
"""Production Table QA model with T5 backbone."""
import torch
import torch.nn as nn
from transformers import T5EncoderModel, T5Tokenizer
from .encoders import StatisticalEncoder
class ProductionTableQA(nn.Module):
    """
    Production Table QA model with T5 backbone.

    Combines:
    - Pretrained language understanding (T5)
    - Statistical table knowledge (StatEncoder)
    - Execution grounding (trained to match executor)
    """

    def __init__(self, model_name: str = "t5-small", stat_dim: int = 768):
        """
        Initialize the Table QA model.

        Args:
            model_name: Pretrained T5 model name
            stat_dim: Dimension for statistical encoder output
        """
        super().__init__()
        # T5 encoder for question understanding. Encoder-only variant: answers
        # come from the regression/classification heads below, not a decoder.
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.text_encoder = T5EncoderModel.from_pretrained(model_name)
        # Statistical sketch encoder (project-local; maps a sketch dict to a
        # stat_dim vector — presumably 1-D, see unsqueeze in forward()).
        self.stat_encoder = StatisticalEncoder(output_dim=stat_dim)
        # Fusion layer (combine text + stats)
        text_dim = self.text_encoder.config.d_model
        self.fusion = nn.Sequential(
            nn.Linear(text_dim + stat_dim, stat_dim),
            nn.LayerNorm(stat_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(stat_dim, stat_dim),
            nn.LayerNorm(stat_dim),
            nn.ReLU(),
        )
        # Output heads: scalar numeric answer and a [0, 1] confidence score.
        self.numeric_head = nn.Sequential(
            nn.Linear(stat_dim, 256), nn.ReLU(), nn.Dropout(0.1), nn.Linear(256, 1)
        )
        self.confidence_head = nn.Sequential(
            nn.Linear(stat_dim, 128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid()
        )
        # Query type classifier (helps model route questions)
        self.query_type_head = nn.Sequential(
            nn.Linear(stat_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 4),  # aggregate, conditional, filter, group_by
        )

    def forward(
        self, question: str, sketch: dict, return_features: bool = False
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            question: Natural language question
            sketch: Statistical sketch dictionary
            return_features: Whether to return intermediate features

        Returns:
            Dictionary with keys:
            - answer: Predicted numerical answer
            - confidence: Confidence score [0, 1]
            - query_type_logits: Query type classification logits
            - features (optional): Fused representation
        """
        # Encode question with T5
        inputs = self.tokenizer(
            question, return_tensors="pt", padding=True, truncation=True, max_length=128
        )
        # FIX: tokenizer tensors are created on CPU; move them to the encoder's
        # device so the module still works after model.to("cuda"/"mps").
        device = next(self.text_encoder.parameters()).device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        text_outputs = self.text_encoder(**inputs)
        text_repr = text_outputs.last_hidden_state.mean(dim=1)  # Mean pooling over tokens
        # Encode table statistics; unsqueeze adds the batch dim so it can be
        # concatenated with text_repr (batch, text_dim).
        stat_repr = self.stat_encoder(sketch).unsqueeze(0)
        # Fuse text and statistical representations
        combined = torch.cat([text_repr, stat_repr], dim=-1)
        fused = self.fusion(combined)
        # Predictions. NOTE: .squeeze() collapses the batch-of-1 dimension,
        # yielding scalars for answer/confidence — callers rely on this shape.
        numeric_answer = self.numeric_head(fused).squeeze()
        confidence = self.confidence_head(fused).squeeze()
        query_type_logits = self.query_type_head(fused).squeeze()
        output = {
            "answer": numeric_answer,
            "confidence": confidence,
            "query_type_logits": query_type_logits,
        }
        if return_features:
            output["features"] = fused
        return output