"""
Plotting utilities for statistical visualizations.
Creates publication-quality plots for insights.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, Literal
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from statqa.metadata.schema import Variable
class PlotFactory:
    """
    Factory for creating statistical visualizations.

    Constructing a factory calls ``sns.set_style`` and ``sns.set_context``,
    so the chosen style/context is applied through seaborn's global settings
    rather than being scoped to this instance.

    Args:
        style: Seaborn style ('whitegrid', 'darkgrid', 'white', 'dark', 'ticks')
        context: Seaborn context ('paper', 'notebook', 'talk', 'poster')
        figsize: Default figure size (width, height)
        dpi: DPI for rasterized output
    """

    # Numeric series with more distinct values than this are drawn as a
    # density histogram with KDE; otherwise as a plain count histogram.
    _CONTINUOUS_UNIQUE_THRESHOLD = 50
    # Upper bound on the number of bins used for count histograms.
    _MAX_DISCRETE_BINS = 30
    # Category axes with more entries than this get rotated tick labels.
    _MAX_UNROTATED_CATEGORIES = 5
    # |skewness| at or above this value is described as a skewed distribution.
    _SKEW_THRESHOLD = 0.5

    def __init__(
        self,
        style: Literal["whitegrid", "darkgrid", "white", "dark", "ticks"] = "whitegrid",
        context: Literal["paper", "notebook", "talk", "poster"] = "notebook",
        figsize: tuple[int, int] = (8, 6),
        dpi: int = 100,
    ) -> None:
        self.figsize = figsize
        self.dpi = dpi
        sns.set_style(style)
        sns.set_context(context)

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------

    def plot_univariate(
        self,
        data: pd.Series,
        variable: Variable,
        output_path: str | Path | None = None,
        return_metadata: bool = False,
    ) -> Figure | tuple[Figure, dict[str, Any]]:
        """
        Create univariate plot (histogram or bar chart).

        Numeric variables get a histogram, categorical variables a bar chart.
        Any other variable yields an axes that carries only a title, and the
        plot type is reported as ``"unknown"``.

        Args:
            data: Data series
            variable: Variable metadata
            output_path: Optional path to save plot
            return_metadata: Whether to return plot metadata alongside figure

        Returns:
            Matplotlib figure, or tuple of (figure, metadata) if return_metadata=True
        """
        fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)

        # Strip declared missing-value codes and NaNs before plotting.
        clean_data = self._clean_data(data, variable)

        if variable.is_numeric():
            plot_type = "histogram"
            self._plot_numeric_distribution(clean_data, variable, ax)
        elif variable.is_categorical():
            plot_type = "bar_chart"
            self._plot_categorical_distribution(clean_data, variable, ax)
        else:
            plot_type = "unknown"
            ax.set_title(f"Distribution of {variable.label}")

        if output_path:
            fig.savefig(output_path, bbox_inches="tight", dpi=self.dpi)

        if return_metadata:
            metadata = self._generate_univariate_metadata(
                clean_data, variable, plot_type, output_path
            )
            return fig, metadata
        return fig

    def plot_bivariate(
        self,
        data: pd.DataFrame,
        var1: Variable,
        var2: Variable,
        output_path: str | Path | None = None,
        return_metadata: bool = False,
    ) -> Figure | tuple[Figure, dict[str, Any]]:
        """
        Create bivariate plot (scatter, box, or heatmap).

        The plot is chosen from the variable kinds, in this order:
        numeric/numeric -> scatter, categorical/numeric -> boxplot,
        categorical/categorical -> heatmap. Every other combination
        (including numeric ``var1`` with categorical ``var2``) draws nothing
        and is reported as ``"unknown"``.

        Args:
            data: DataFrame with both variables
            var1: First variable
            var2: Second variable
            output_path: Optional path to save plot
            return_metadata: Whether to return plot metadata alongside figure

        Returns:
            Matplotlib figure, or tuple of (figure, metadata) if return_metadata=True
        """
        fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)

        # Keep only the two columns, blank out missing codes, drop incomplete rows.
        subset = data[[var1.name, var2.name]].copy()
        subset = self._clean_dataframe(subset, [var1, var2])
        subset = subset.dropna()

        if var1.is_numeric() and var2.is_numeric():
            plot_type = "scatter"
            self._plot_scatter(subset, var1, var2, ax)
        elif var1.is_categorical() and var2.is_numeric():
            plot_type = "boxplot"
            self._plot_boxplot(subset, var1, var2, ax)
        elif var1.is_categorical() and var2.is_categorical():
            plot_type = "heatmap"
            self._plot_heatmap(subset, var1, var2, ax)
        else:
            plot_type = "unknown"

        if output_path:
            fig.savefig(output_path, bbox_inches="tight", dpi=self.dpi)

        if return_metadata:
            metadata = self._generate_bivariate_metadata(subset, var1, var2, plot_type, output_path)
            return fig, metadata
        return fig

    def plot_temporal(
        self,
        data: pd.DataFrame,
        time_var: Variable,
        value_var: Variable,
        group_var: Variable | None = None,
        output_path: str | Path | None = None,
    ) -> Figure:
        """
        Create temporal trend plot.

        Draws one line of ``value_var`` against ``time_var``; when
        ``group_var`` is given, one labelled line per group plus a legend.

        Args:
            data: DataFrame with time and value
            time_var: Time variable
            value_var: Value variable
            group_var: Optional grouping variable
            output_path: Optional path to save plot

        Returns:
            Matplotlib figure
        """
        fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)

        variables = [time_var, value_var]
        if group_var is not None:
            variables.append(group_var)

        # Clean, drop incomplete rows and order chronologically so lines are
        # drawn left to right.
        subset = data[[var.name for var in variables]].copy()
        subset = self._clean_dataframe(subset, variables)
        subset = subset.dropna().sort_values(time_var.name)

        if group_var is not None:
            for group_name, group_data in subset.groupby(group_var.name):
                ax.plot(
                    group_data[time_var.name],
                    group_data[value_var.name],
                    marker="o",
                    label=self._label_for(group_var, group_name),
                )
            ax.legend()
        else:
            ax.plot(subset[time_var.name], subset[value_var.name], marker="o", linewidth=2)

        ax.set_xlabel(time_var.label)
        ax.set_ylabel(value_var.label)
        ax.set_title(f"{value_var.label} over {time_var.label}")
        ax.grid(True, alpha=0.3)

        if output_path:
            fig.savefig(output_path, bbox_inches="tight", dpi=self.dpi)
        return fig

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _label_for(variable: Variable, value: Any) -> str:
        """
        Return the display label for a raw category value.

        ``variable.valid_values`` is looked up first with the raw value and
        then with ``str(value)``, so the mapping works whichever of the two
        key forms it uses. Falls back to ``str(value)`` when there is no
        mapping or no matching key.
        """
        mapping = variable.valid_values
        if not mapping:
            return str(value)
        if value in mapping:
            return mapping[value]
        return mapping.get(str(value), str(value))

    def _is_continuous(self, data: pd.Series) -> bool:
        """Tell whether a numeric series is plotted as a density histogram."""
        return data.nunique() > self._CONTINUOUS_UNIQUE_THRESHOLD

    def _describe_shape(self, skewness: float) -> str:
        """Turn a skewness value into a caption phrase (article included)."""
        if pd.isna(skewness):
            # Skewness is undefined, e.g. for fewer than three observations.
            return "a distribution of indeterminate shape"
        if skewness >= self._SKEW_THRESHOLD:
            return "a right-skewed distribution"
        if skewness <= -self._SKEW_THRESHOLD:
            return "a left-skewed distribution"
        return "an approximately normal distribution"

    def _rotate_crowded_xticklabels(self, ax: Axes, n_categories: int) -> None:
        """Rotate x tick labels when there are many categories."""
        if n_categories > self._MAX_UNROTATED_CATEGORIES:
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

    def _clean_data(self, data: pd.Series, variable: Variable) -> pd.Series:
        """Clean missing values from series.

        Values listed in ``variable.missing_values`` are replaced by NaN and
        all NaNs are then dropped. The input series is not modified.
        """
        clean = data
        if variable.missing_values:
            clean = clean.replace(dict.fromkeys(variable.missing_values, np.nan))
        return clean.dropna()

    def _clean_dataframe(self, data: pd.DataFrame, variables: list[Variable]) -> pd.DataFrame:
        """Clean missing values from dataframe.

        For each variable, its declared missing-value codes are replaced by
        NaN in the column of the same name. Rows are NOT dropped here; the
        caller decides. Works on a copy of ``data``.
        """
        clean = data.copy()
        for var in variables:
            if var.missing_values:
                clean[var.name] = clean[var.name].replace(dict.fromkeys(var.missing_values, np.nan))
        return clean

    # ------------------------------------------------------------------
    # Drawing primitives
    # ------------------------------------------------------------------

    def _plot_numeric_distribution(self, data: pd.Series, variable: Variable, ax: Axes) -> None:
        """Plot histogram/KDE for numeric variable, with a dashed mean line."""
        n_unique = data.nunique()
        if n_unique > self._CONTINUOUS_UNIQUE_THRESHOLD:
            # Many distinct values: treat as continuous, overlay a KDE.
            sns.histplot(data, kde=True, ax=ax, stat="density")
            ax.set_ylabel("Density")
        else:
            # Few distinct values: one bin per value, capped. The lower bound
            # of 1 keeps the bin count valid for an empty series.
            n_bins = max(1, min(n_unique, self._MAX_DISCRETE_BINS))
            sns.histplot(data, kde=False, ax=ax, bins=n_bins)
            ax.set_ylabel("Count")
        ax.set_xlabel(variable.label)

        mean = data.mean()
        if pd.notna(mean):
            # Skipped for empty data, where the mean is NaN and there is
            # nothing meaningful to mark or to put in a legend.
            ax.axvline(mean, color="red", linestyle="--", label=f"Mean: {mean:.2f}", alpha=0.7)
            ax.legend()

    def _plot_categorical_distribution(self, data: pd.Series, variable: Variable, ax: Axes) -> None:
        """Plot bar chart of category frequencies (most frequent first)."""
        counts = data.value_counts()
        if variable.valid_values:
            # Show human-readable labels instead of raw codes.
            counts.index = counts.index.map(lambda raw: self._label_for(variable, raw))

        sns.barplot(x=counts.index, y=counts.values, ax=ax, palette="viridis")
        ax.set_xlabel(variable.label)
        ax.set_ylabel("Count")
        self._rotate_crowded_xticklabels(ax, len(counts))

    def _plot_scatter(self, data: pd.DataFrame, var1: Variable, var2: Variable, ax: Axes) -> None:
        """Plot scatter plot with regression line (var1 on x, var2 on y)."""
        sns.regplot(
            x=var1.name,
            y=var2.name,
            data=data,
            ax=ax,
            scatter_kws={"alpha": 0.5},
            line_kws={"color": "red"},
        )
        ax.set_xlabel(var1.label)
        ax.set_ylabel(var2.label)
        ax.set_title(f"{var1.label} vs {var2.label}")

    def _plot_boxplot(
        self, data: pd.DataFrame, var_cat: Variable, var_num: Variable, ax: Axes
    ) -> None:
        """Plot box plot for categorical vs numeric."""
        plot_data = data.copy()
        if var_cat.valid_values:
            # Show human-readable labels instead of raw codes.
            plot_data[var_cat.name] = plot_data[var_cat.name].map(
                lambda raw: self._label_for(var_cat, raw)
            )

        sns.boxplot(x=var_cat.name, y=var_num.name, data=plot_data, ax=ax, palette="Set2")
        ax.set_xlabel(var_cat.label)
        ax.set_ylabel(var_num.label)
        ax.set_title(f"{var_num.label} by {var_cat.label}")
        self._rotate_crowded_xticklabels(ax, plot_data[var_cat.name].nunique())

    def _plot_heatmap(self, data: pd.DataFrame, var1: Variable, var2: Variable, ax: Axes) -> None:
        """Plot heatmap for categorical vs categorical (var1 rows, var2 columns)."""
        contingency = pd.crosstab(data[var1.name], data[var2.name])

        if var1.valid_values:
            contingency.index = contingency.index.map(lambda raw: self._label_for(var1, raw))
        if var2.valid_values:
            contingency.columns = contingency.columns.map(
                lambda raw: self._label_for(var2, raw)
            )

        sns.heatmap(contingency, annot=True, fmt="d", cmap="YlOrRd", ax=ax)
        ax.set_xlabel(var2.label)
        ax.set_ylabel(var1.label)
        ax.set_title(f"{var1.label} vs {var2.label}")

    # ------------------------------------------------------------------
    # Metadata generation
    # ------------------------------------------------------------------

    def _generate_univariate_metadata(
        self,
        data: pd.Series,
        variable: Variable,
        plot_type: str,
        output_path: str | Path | None,
    ) -> dict[str, Any]:
        """Generate metadata for univariate plots.

        Always contains ``plot_type``, ``caption``, ``alt_text`` and
        ``visual_elements``; ``primary_plot`` and ``generation_code`` are
        added only when the plot was saved to ``output_path``.
        """
        metadata: dict[str, Any] = {
            "plot_type": plot_type,
            "caption": self._generate_univariate_caption(data, variable),
            "alt_text": self._generate_univariate_alt_text(data, variable, plot_type),
            "visual_elements": self._extract_univariate_visual_elements(data, variable, plot_type),
        }
        if output_path:
            metadata["primary_plot"] = str(output_path)
            metadata["generation_code"] = (
                f"plot_factory.plot_univariate(data['{variable.name}'], "
                f"{variable.name}_var, '{output_path}')"
            )
        return metadata

    def _generate_bivariate_metadata(
        self,
        data: pd.DataFrame,
        var1: Variable,
        var2: Variable,
        plot_type: str,
        output_path: str | Path | None,
    ) -> dict[str, Any]:
        """Generate metadata for bivariate plots.

        Always contains ``plot_type``, ``caption``, ``alt_text`` and
        ``visual_elements``; ``primary_plot`` and ``generation_code`` are
        added only when the plot was saved to ``output_path``.
        """
        metadata: dict[str, Any] = {
            "plot_type": plot_type,
            "caption": self._generate_bivariate_caption(data, var1, var2, plot_type),
            "alt_text": self._generate_bivariate_alt_text(data, var1, var2, plot_type),
            "visual_elements": self._extract_bivariate_visual_elements(data, var1, var2, plot_type),
        }
        if output_path:
            metadata["primary_plot"] = str(output_path)
            metadata["generation_code"] = (
                f"plot_factory.plot_bivariate(data, {var1.name}_var, "
                f"{var2.name}_var, '{output_path}')"
            )
        return metadata

    def _generate_univariate_caption(self, data: pd.Series, variable: Variable) -> str:
        """Generate descriptive caption for univariate plots.

        Numeric variables get mean, standard deviation, N and a shape
        description derived from the skewness. Everything else is described
        as a bar chart with its most common category.
        """
        if variable.is_numeric():
            shape = self._describe_shape(data.skew())
            return (
                f"Histogram showing {variable.label.lower()} distribution "
                f"with mean={data.mean():.2f} and std={data.std():.2f} "
                f"(N={len(data)}). The data shows {shape}."
            )

        counts = data.value_counts()
        summary = (
            f"Bar chart showing {variable.label.lower()} frequencies "
            f"across {len(counts)} categories (N={len(data)})."
        )
        if counts.empty:
            # No observations: there is no mode to report.
            return summary

        mode = counts.idxmax()
        mode_pct = (counts.max() / len(data)) * 100
        mode_label = self._label_for(variable, mode)
        return f"{summary} Most common category is '{mode_label}' ({mode_pct:.1f}%)."

    def _generate_bivariate_caption(
        self, data: pd.DataFrame, var1: Variable, var2: Variable, plot_type: str
    ) -> str:
        """Generate descriptive caption for bivariate plots."""
        if plot_type == "scatter":
            intro = (
                f"Scatter plot showing the relationship between {var1.label} and "
                f"{var2.label} (N={len(data)})."
            )
            # Correlate the two named columns explicitly rather than relying
            # on their position in the frame.
            correlation = data[var1.name].corr(data[var2.name])
            if pd.isna(correlation):
                # Undefined for a constant column or too few observations.
                return f"{intro} Correlation could not be computed."

            magnitude = abs(correlation)
            if magnitude < 0.3:
                strength = "weak"
            elif magnitude < 0.7:
                strength = "moderate"
            else:
                strength = "strong"
            direction = "positive" if correlation > 0 else "negative"
            return (
                f"{intro} Shows a {strength} {direction} "
                f"correlation (r={correlation:.2f}) with regression line."
            )

        if plot_type == "boxplot":
            n_groups = data[var1.name].nunique()
            return (
                f"Box plots comparing {var2.label} across {n_groups} "
                f"{var1.label.lower()} groups (N={len(data)}). Shows "
                f"distribution differences and potential outliers."
            )

        if plot_type == "heatmap":
            n_var1 = data[var1.name].nunique()
            n_var2 = data[var2.name].nunique()
            return (
                f"Heatmap showing the relationship between {var1.label} "
                f"({n_var1} categories) and {var2.label} ({n_var2} categories). "
                f"Color intensity represents frequency counts."
            )

        return f"Bivariate plot showing {var1.label} vs {var2.label}"

    def _generate_univariate_alt_text(
        self, data: pd.Series, variable: Variable, plot_type: str
    ) -> str:
        """Generate accessibility alt-text for univariate plots."""
        if plot_type == "histogram":
            # Mirror the y-axis actually drawn by _plot_numeric_distribution.
            y_desc = "frequency density" if self._is_continuous(data) else "count frequencies"
            return (
                f"Histogram chart with {variable.label.lower()} values on x-axis "
                f"and {y_desc} on y-axis, showing distribution shape "
                f"with {len(data)} observations."
            )
        if plot_type == "bar_chart":
            n_categories = data.nunique()
            return (
                f"Bar chart with {n_categories} categories of {variable.label.lower()} "
                f"on x-axis and count frequencies on y-axis."
            )
        return f"Chart showing {variable.label.lower()} distribution"

    def _generate_bivariate_alt_text(
        self, data: pd.DataFrame, var1: Variable, var2: Variable, plot_type: str
    ) -> str:
        """Generate accessibility alt-text for bivariate plots."""
        if plot_type == "scatter":
            return (
                f"Scatter plot with {var1.label} on x-axis and {var2.label} "
                f"on y-axis, showing {len(data)} data points with regression line."
            )
        if plot_type == "boxplot":
            n_groups = data[var1.name].nunique()
            return (
                f"Box plot chart with {n_groups} {var1.label.lower()} categories "
                f"on x-axis and {var2.label} values on y-axis."
            )
        if plot_type == "heatmap":
            return (
                f"Heatmap with {var1.label} categories on y-axis and "
                f"{var2.label} categories on x-axis, using color intensity "
                f"to show frequency counts."
            )
        return f"Chart showing relationship between {var1.label} and {var2.label}"

    def _extract_univariate_visual_elements(
        self, data: pd.Series, variable: Variable, plot_type: str
    ) -> dict[str, Any]:
        """Extract visual elements description for univariate plots."""
        elements: dict[str, Any] = {
            "chart_type": plot_type,
            "x_axis": variable.label,
            "key_features": [],
            "colors": [],
            "annotations": [],
        }
        if plot_type == "histogram":
            # Mirror the y-axis actually drawn by _plot_numeric_distribution.
            elements["y_axis"] = "Density" if self._is_continuous(data) else "Count"
            elements["colors"] = ["blue bars", "red mean line"]
            elements["key_features"] = ["distribution shape", "mean line"]
            elements["annotations"] = [f"Mean: {data.mean():.2f}"]
            # Same threshold as the caption; NaN skewness compares False.
            if abs(data.skew()) >= self._SKEW_THRESHOLD:
                elements["key_features"].append("skewed distribution")
        elif plot_type == "bar_chart":
            elements["y_axis"] = "Count"
            elements["colors"] = ["viridis color palette"]
            elements["key_features"] = ["frequency bars"]
        return elements

    def _extract_bivariate_visual_elements(
        self, data: pd.DataFrame, var1: Variable, var2: Variable, plot_type: str
    ) -> dict[str, Any]:
        """Extract visual elements description for bivariate plots."""
        # The heatmap puts var2 on the x-axis and var1 on the y-axis; every
        # other plot type uses var1 for x and var2 for y.
        is_heatmap = plot_type == "heatmap"
        elements: dict[str, Any] = {
            "chart_type": plot_type,
            "x_axis": var2.label if is_heatmap else var1.label,
            "y_axis": var1.label if is_heatmap else var2.label,
            "key_features": [],
            "colors": [],
            "annotations": [],
        }
        if plot_type == "scatter":
            elements["colors"] = ["blue points", "red regression line"]
            elements["key_features"] = ["data points", "regression line", "trend"]
        elif plot_type == "boxplot":
            elements["colors"] = ["Set2 color palette"]
            elements["key_features"] = ["boxes", "whiskers", "outliers", "medians"]
        elif is_heatmap:
            elements["colors"] = ["YlOrRd color map"]
            elements["key_features"] = ["color intensity", "frequency counts"]
            elements["annotations"] = ["count values in cells"]
        return elements

    def close_all(self) -> None:
        """Close all open figures."""
        plt.close("all")