# Source code for tabula_rasa.core.executor
"""Query execution on actual data for training and validation."""
import re
from dataclasses import dataclass
from typing import Any
import pandas as pd
[docs]
@dataclass
class Query:
"""Structured query representation."""
query_type: str # 'aggregate', 'filter', 'conditional', 'join'
target_column: str | None = None
aggregation: str | None = None # 'mean', 'sum', 'count', 'std', 'percentile'
condition: str | None = None
percentile: float | None = None
group_by: str | None = None
class AdvancedQueryExecutor:
    """
    Execute structured queries against an in-memory DataFrame.

    Supports: aggregations, filters, conditional aggregations, group-by.
    """

    # Reductions that map one-to-one onto a pandas Series method of the same name.
    _SERIES_REDUCTIONS = ("mean", "sum", "std", "min", "max")
    # Every aggregation accepted by 'aggregate' and 'conditional' queries.
    _SCALAR_AGGREGATIONS = _SERIES_REDUCTIONS + ("count", "percentile")
    # Aggregations accepted by 'group_by' queries.
    _GROUP_AGGREGATIONS = ("mean", "count", "sum")
    # Comparison operator -> name of the equivalent pandas Series method.
    _COMPARATORS = {
        ">": "gt",
        "<": "lt",
        ">=": "ge",
        "<=": "le",
        "==": "eq",
        "!=": "ne",
    }
    # "<column> <op> <number>"; the number may be signed, fractional and/or
    # written in scientific notation. Two-character operators are listed first
    # so the intent of the alternation is obvious.
    _CONDITION_RE = re.compile(
        r"(\w+)\s*(>=|<=|==|!=|>|<)\s*"
        r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"
    )

    def __init__(self, df: pd.DataFrame):
        """
        Initialize query executor.

        Args:
            df: DataFrame to execute queries against
        """
        self.df = df

    def execute(self, query: Query) -> Any:
        """
        Execute structured query.

        Args:
            query: Query object specifying the operation

        Returns:
            Query result (float for aggregates and conditionals, int for
            filters, dict for group-by)

        Raises:
            ValueError: If the query type, aggregation or condition is not
                recognised
        """
        handlers = {
            "aggregate": self._execute_aggregate,
            "conditional": self._execute_conditional,
            "filter": self._execute_filter,
            "group_by": self._execute_group_by,
        }
        handler = handlers.get(query.query_type)
        if handler is None:
            raise ValueError(f"Unknown query type: {query.query_type}")
        return handler(query)

    def _aggregate_series(self, series: pd.Series, query: Query) -> float:
        """
        Reduce a Series to one float according to ``query.aggregation``.

        Shared by the 'aggregate' and 'conditional' paths so both accept the
        same set of aggregations.

        Args:
            series: Values to aggregate
            query: Query providing ``aggregation`` and, for 'percentile',
                ``percentile``

        Returns:
            The aggregated value as a float

        Raises:
            ValueError: If the aggregation is unknown, or 'percentile' is
                requested without a percentile value
        """
        aggregation = query.aggregation
        if aggregation == "count":
            # Row count: uses len(), so missing values are counted too.
            return float(len(series))
        if aggregation == "percentile":
            if query.percentile is None:
                raise ValueError("Percentile aggregation requires a percentile value")
            return float(series.quantile(query.percentile))
        if aggregation in self._SERIES_REDUCTIONS:
            return float(getattr(series, aggregation)())
        raise ValueError(f"Unknown aggregation: {aggregation}")

    def _execute_aggregate(self, query: Query) -> float:
        """
        Execute aggregation query over the whole target column.

        Raises:
            ValueError: If the aggregation is unknown
        """
        return self._aggregate_series(self.df[query.target_column], query)

    def _execute_conditional(self, query: Query) -> float:
        """
        Execute conditional aggregation (e.g., mean of X where Y > 10).

        When no row satisfies the condition, 'count' and 'sum' return 0.0 and
        every other aggregation returns NaN.

        Raises:
            ValueError: If the condition cannot be parsed or the aggregation
                is unknown
        """
        mask = self._parse_condition(query.condition)
        matching = self.df.loc[mask, query.target_column]
        aggregation = query.aggregation
        # Validate before the empty-selection shortcut so a bad aggregation
        # name is reported even when nothing matches.
        if aggregation not in self._SCALAR_AGGREGATIONS:
            raise ValueError(f"Unknown aggregation: {aggregation}")
        if matching.empty and aggregation not in ("count", "sum"):
            return float("nan")
        return self._aggregate_series(matching, query)

    def _execute_filter(self, query: Query) -> int:
        """Execute filter query (count matching rows)."""
        mask = self._parse_condition(query.condition)
        return int(mask.sum())

    def _execute_group_by(self, query: Query) -> dict:
        """
        Execute group-by aggregation.

        Returns:
            Mapping of group key to aggregated value. 'count' uses pandas'
            group-by count, which excludes missing values.

        Raises:
            ValueError: If the aggregation is not 'mean', 'count' or 'sum'
        """
        if query.aggregation not in self._GROUP_AGGREGATIONS:
            raise ValueError(f"Unknown aggregation: {query.aggregation}")
        grouped = self.df.groupby(query.group_by)[query.target_column]
        return getattr(grouped, query.aggregation)().to_dict()

    def _parse_condition(self, condition: str) -> pd.Series:
        """
        Parse condition string into boolean mask.

        Supported operators: >, <, >=, <=, ==, !=. The whole string must be a
        single comparison; trailing text is rejected rather than ignored.

        Args:
            condition: String like "column > 10", "price <= 100.5" or
                "delta > -1e-3"

        Returns:
            Boolean Series mask

        Raises:
            ValueError: If condition cannot be parsed
            KeyError: If the column is not in the DataFrame (raised by pandas)
        """
        match = None
        if isinstance(condition, str):
            match = self._CONDITION_RE.fullmatch(condition.strip())
        if match is None:
            raise ValueError(f"Cannot parse condition: {condition}")
        column, op, literal = match.groups()
        compare = getattr(self.df[column], self._COMPARATORS[op])
        return compare(float(literal))