# Source code for tabula_rasa.training.trainer

"""Production training pipeline with best practices."""

import copy

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from ..models.qa_model import ProductionTableQA
from .dataset import TableQADataset


class ProductionTrainer:
    """Train and validate a table-QA model on questions generated from a DataFrame.

    The trainer owns an AdamW optimizer (weight decay 0.01) and a
    ``ReduceLROnPlateau`` scheduler (factor 0.5, patience 3) that is stepped
    with the validation loss after every epoch.

    The training objective is a weighted sum of three per-sample terms:
    a normalized answer MSE, a confidence-calibration MSE and a query-type
    cross-entropy (weights ``answer_weight``, ``confidence_weight`` and
    ``query_type_weight``).
    """

    def __init__(
        self,
        model: "ProductionTableQA",
        df: pd.DataFrame,
        sketch: dict,
        lr: float = 1e-4,
        batch_size: int = 16,
        device: str = "cpu",
    ):
        """
        Initialize the trainer.

        Args:
            model: ProductionTableQA model to train
            df: Training DataFrame
            sketch: Statistical sketch of the DataFrame
            lr: Learning rate
            batch_size: Batch size for training
            device: Device to train on ('cpu' or 'cuda')
        """
        self.model = model.to(device)
        self.df = df
        self.sketch = sketch
        self.batch_size = batch_size
        self.device = device

        # Optimizer with weight decay
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, weight_decay=0.01
        )

        # Learning rate scheduler, driven by the validation loss in train()
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="min", factor=0.5, patience=3
        )

        # Loss weights
        self.answer_weight = 1.0
        self.confidence_weight = 0.1
        self.query_type_weight = 0.2

        # Copy of the model weights from the best validation epoch; filled in
        # by train(), None until a training run has completed an epoch.
        self.best_state_dict = None

    def train(
        self, n_epochs: int = 10, n_train_samples: int = 1000, n_val_samples: int = 200
    ) -> tuple[float, dict]:
        """
        Training loop with validation.

        After each epoch the scheduler is stepped with the validation loss and,
        when that loss is the lowest seen so far, a deep copy of the model's
        ``state_dict`` is stored in ``self.best_state_dict``.

        Args:
            n_epochs: Number of training epochs
            n_train_samples: Number of training samples to generate
            n_val_samples: Number of validation samples to generate

        Returns:
            Tuple of (best_val_loss, history_dict). ``history_dict`` maps
            "train_loss", "val_loss", "val_mae" and "val_mape" to per-epoch
            lists. If either generated dataset is empty, a warning is printed
            and ``(inf, {})`` is returned without training.
        """
        # Create datasets
        train_dataset = TableQADataset(self.df, self.sketch, n_train_samples)
        val_dataset = TableQADataset(self.df, self.sketch, n_val_samples)

        if len(train_dataset) == 0 or len(val_dataset) == 0:
            print("Warning: No valid samples generated")
            return float("inf"), {}

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size)

        best_val_loss = float("inf")
        history = {"train_loss": [], "val_loss": [], "val_mae": [], "val_mape": []}

        for epoch in range(n_epochs):
            train_loss = self._train_epoch(train_loader)
            val_loss, val_metrics = self._validate(val_loader)

            # ReduceLROnPlateau needs the monitored metric
            self.scheduler.step(val_loss)

            history["train_loss"].append(train_loss)
            history["val_loss"].append(val_loss)
            history["val_mae"].append(val_metrics["mae"])
            history["val_mape"].append(val_metrics["mape"])

            # Log every second epoch
            if epoch % 2 == 0:
                print(f"Epoch {epoch}/{n_epochs}")
                print(f" Train Loss: {train_loss:.4f}")
                print(f" Val Loss: {val_loss:.4f}")
                print(f" Val MAE: {val_metrics['mae']:.4f}")
                print(f" Val MAPE: {val_metrics['mape']:.2%}")
                print(f" Query Type Acc: {val_metrics['query_type_acc']:.2%}")

            # Save best model: keep an independent copy of the weights so later
            # epochs cannot overwrite the best checkpoint.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.best_state_dict = copy.deepcopy(self.model.state_dict())

        return best_val_loss, history

    @staticmethod
    def _target_as_float(value) -> float:
        """Return a batch target (0-dim tensor or plain number) as a Python float."""
        return value.item() if torch.is_tensor(value) else float(value)

    def _train_epoch(self, dataloader: DataLoader) -> float:
        """Run one training epoch.

        Samples are fed to the model one at a time; their losses are summed
        per batch and a single optimizer step (with gradient-norm clipping at
        1.0) is taken for each batch. Samples whose target is NaN or infinite
        are skipped before the model is called.

        Returns:
            Sum of the batch losses divided by the number of batches.
        """
        self.model.train()
        total_loss = 0.0

        for batch in dataloader:
            samples = zip(batch["question"], batch["answer"], batch["query_type"])

            batch_loss = 0
            for q, true_ans, qt in samples:
                true_ans_val = self._target_as_float(true_ans)

                # Skip invalid targets before spending a forward pass on them
                if not np.isfinite(true_ans_val):
                    continue

                output = self.model(q, self.sketch)

                # Answer loss (MSE), normalized by the target magnitude
                true_ans_tensor = torch.tensor(
                    true_ans_val, dtype=torch.float32, device=self.device
                )
                answer_loss = F.mse_loss(output["answer"], true_ans_tensor) / (
                    abs(true_ans_val) + 1.0
                )

                # Confidence calibration: the target is 1 minus the relative
                # error, floored at 0, and is computed without gradients.
                with torch.no_grad():
                    error = abs(output["answer"].item() - true_ans_val)
                    target_conf = max(0, 1 - error / (abs(true_ans_val) + 1e-6))
                conf_loss = F.mse_loss(
                    output["confidence"],
                    torch.tensor(target_conf, dtype=torch.float32, device=self.device),
                )

                # Query type classification
                qt_loss = F.cross_entropy(
                    output["query_type_logits"].unsqueeze(0),
                    qt.unsqueeze(0).to(self.device),
                )

                batch_loss = batch_loss + (
                    self.answer_weight * answer_loss
                    + self.confidence_weight * conf_loss
                    + self.query_type_weight * qt_loss
                )

            # batch_loss is still the int 0 when every sample was skipped; the
            # "> 0" test also rejects a NaN sum, so neither case reaches backward().
            if torch.is_tensor(batch_loss) and batch_loss > 0:
                self.optimizer.zero_grad()
                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                total_loss += batch_loss.item()

        return total_loss / max(len(dataloader), 1)

    def _validate(self, dataloader: DataLoader) -> tuple[float, dict]:
        """Evaluate the model without gradient tracking.

        Samples whose target is NaN or infinite are skipped. Predictions that
        are NaN or infinite still contribute to the loss and to query-type
        accuracy but are left out of MAE/MAPE.

        Returns:
            Tuple of (loss, metrics). ``loss`` is the sum of per-sample answer
            MSEs divided by the number of batches. ``metrics`` has the keys
            "mae", "mape" (absolute error over ``max(|target|, 1e-6)``) and
            "query_type_acc"; each is 0.0 when no sample contributed.
        """
        self.model.eval()
        total_loss = 0.0
        errors = []
        true_values = []
        query_type_correct = 0
        query_type_total = 0

        with torch.no_grad():
            for batch in dataloader:
                samples = zip(batch["question"], batch["answer"], batch["query_type"])

                for q, true_ans, qt in samples:
                    true_ans_val = self._target_as_float(true_ans)

                    # Skip invalid answers
                    if not np.isfinite(true_ans_val):
                        continue

                    output = self.model(q, self.sketch)

                    # Loss
                    true_ans_tensor = torch.tensor(
                        true_ans_val, dtype=torch.float32, device=self.device
                    )
                    total_loss += F.mse_loss(output["answer"], true_ans_tensor).item()

                    # Metrics (only for finite predictions)
                    pred_val = output["answer"].item()
                    if np.isfinite(pred_val):
                        errors.append(abs(pred_val - true_ans_val))
                        true_values.append(abs(true_ans_val))

                    # Query type accuracy
                    pred_qt = output["query_type_logits"].argmax().item()
                    if pred_qt == qt.item():
                        query_type_correct += 1
                    query_type_total += 1

        mae = np.mean(errors) if errors else 0.0

        # MAPE with the denominator floored so zero targets do not divide by zero
        if errors and true_values:
            mape = np.mean([e / max(tv, 1e-6) for e, tv in zip(errors, true_values)])
        else:
            mape = 0.0

        metrics = {
            "mae": mae,
            "mape": mape,
            "query_type_acc": query_type_correct / max(query_type_total, 1),
        }

        return total_loss / max(len(dataloader), 1), metrics