# Source code for search_and_replace.correctors

"""Text correction classes."""

from __future__ import annotations

import re
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from importlib.resources import files
from pathlib import Path
from typing import TYPE_CHECKING

import hyperscan
from symspellpy import SymSpell, Verbosity

if TYPE_CHECKING:
    from typing import Any

# OCR confusion pairs
OCR_CONFUSIONS: dict[str, str] = {
    "0": "O",
    "1": "l",
    "|": "l",
    "rn": "m",
    "cl": "d",
    "vv": "w",
    "ii": "u",
    "lj": "y",
}


class SpellCorrector:
    """Levenshtein-based spelling correction using SymSpell.

    Exactly one word source is used, chosen in priority order:
    ``words``, then ``dictionary``, then the bundled frequency list.

    Args:
        words: Custom word list. If None and dictionary is None, uses
            bundled dictionary.
        dictionary: Path to dictionary file (one word per line).
        max_distance: Maximum edit distance (default: 2).
    """

    def __init__(
        self,
        words: list[str] | None = None,
        dictionary: Path | str | None = None,
        max_distance: int = 2,
    ) -> None:
        self.max_distance = max_distance
        self._sym = SymSpell(max_dictionary_edit_distance=max_distance)
        if words is not None:
            self._add_words(words)
        elif dictionary is not None:
            self._load_file(Path(dictionary))
        else:
            self._load_bundled()

    def _add_words(self, words: Iterable[str]) -> None:
        """Register each word, stripped and lowercased, at frequency 1.

        Blank entries are skipped. Shared by all three word sources so
        that custom lists, dictionary files, and the bundled list are
        normalized identically.
        """
        for word in words:
            word = word.strip()
            if word:
                self._sym.create_dictionary_entry(word.lower(), 1)

    def _load_file(self, path: Path) -> None:
        """Load a user-supplied dictionary file (one word per line, UTF-8)."""
        with path.open(encoding="utf-8") as f:
            self._add_words(f)

    def _load_bundled(self) -> None:
        """Load the English word list bundled with the package."""
        data = files("search_and_replace") / "data" / "english_50k.txt"
        with data.open("r", encoding="utf-8") as f:
            self._add_words(f)

    def correct(self, word: str) -> str:
        """Correct a single word.

        Lookup is case-insensitive; when the input starts with an
        uppercase letter the suggestion is re-capitalized. Returns the
        input unchanged if no suggestion lies within ``max_distance``.
        """
        suggestions = self._sym.lookup(
            word.lower(), Verbosity.CLOSEST, max_edit_distance=self.max_distance
        )
        if suggestions:
            corrected: str = suggestions[0].term
            if word and word[0].isupper():
                corrected = corrected.capitalize()
            return corrected
        return word

    def correct_text(self, text: str) -> str:
        """Correct all alphabetic words in text; other tokens pass through."""
        # Tokenize into word runs and non-word runs so the joined result
        # preserves all original punctuation and whitespace.
        tokens = re.findall(r"\b\w+\b|\W+", text)
        return "".join(self.correct(t) if t.isalpha() else t for t in tokens)
class OCRCorrector:
    """Corrects common OCR character confusions (0/O, 1/l, rn/m, etc.)."""

    def correct(self, text: str) -> str:
        """Return ``text`` with every known OCR confusion substituted.

        Substitutions run one after another in OCR_CONFUSIONS insertion
        order, so an earlier fix may enable a later one.
        """
        fixed = text
        for misread, intended in OCR_CONFUSIONS.items():
            fixed = fixed.replace(misread, intended)
        return fixed
@dataclass
class _Match:
    # A single hyperscan hit. Offsets start out as byte offsets into the
    # scanned UTF-8 buffer and are remapped to character indices by
    # PatternCorrector._byte_to_char before splicing.
    pattern_id: int  # id assigned at compile time; keys into _replacements
    start: int  # inclusive start offset of the match
    end: int  # exclusive end offset of the match
    replacement: str  # canonical word to substitute for the matched span
class PatternCorrector:
    """Hyperscan-based pattern matching for fuzzy single-character errors.

    For each word, one expression per interior character position is
    compiled: the prefix and suffix match literally while that one
    character may be replaced by up to ``max_errors`` arbitrary
    characters.

    Args:
        patterns: List of (word, max_errors) tuples.
    """

    def __init__(self, patterns: Sequence[tuple[str, int]]) -> None:
        self._db: Any = None  # hyperscan.Database, or None when nothing compiled
        self._replacements: dict[int, str] = {}  # pattern id -> canonical word
        self._compile(patterns)

    def _compile(self, patterns: Sequence[tuple[str, int]]) -> None:
        # Build the expression set. The trailing `\??[\r\n]*` additionally
        # tolerates an optional '?' and line breaks at the error position
        # — NOTE(review): presumably OCR/hyphenation artifacts; confirm.
        expressions: list[bytes] = []
        ids: list[int] = []
        flags: list[int] = []
        pid = 0
        for word, max_errors in patterns:
            word = word.strip("\ufeff")  # drop a stray byte-order mark
            if len(word) < 3:
                # No interior character to vary; skip the word entirely.
                continue
            for i in range(1, len(word) - 1):
                prefix = re.escape(word[:i])
                suffix = re.escape(word[i + 1 :])
                regex = prefix + rf".{{0,{max_errors}}}\??[\r\n]*" + suffix
                expressions.append(regex.encode("utf-8"))
                ids.append(pid)
                flags.append(
                    hyperscan.HS_FLAG_CASELESS
                    | hyperscan.HS_FLAG_DOTALL
                    | hyperscan.HS_FLAG_SOM_LEFTMOST  # report true match starts
                )
                self._replacements[pid] = word
                pid += 1
        if expressions:
            self._db = hyperscan.Database()
            self._db.compile(expressions=expressions, ids=ids, flags=flags)

    def _scan(self, text_bytes: bytes) -> list[_Match]:
        # Collect raw matches (byte offsets) via the scan context list.
        matches: list[_Match] = []

        def on_match(pid: int, start: int, end: int, _flags: int, ctx: list[_Match]) -> None:
            ctx.append(_Match(pid, start, end, self._replacements[pid]))

        if self._db:
            self._db.scan(text_bytes, match_event_handler=on_match, context=matches)
        return matches

    def _byte_to_char(self, text: str, matches: list[_Match]) -> list[_Match]:
        # Hyperscan reports byte offsets; translate to character indices
        # so slicing `text` stays correct for multi-byte UTF-8 input.
        mapping: list[int] = []
        for i, char in enumerate(text):
            mapping.extend([i] * len(char.encode("utf-8")))
        mapping.append(len(text))  # sentinel: end offset == len(bytes) maps to len(text)
        return [
            _Match(
                m.pattern_id, mapping[m.start], mapping[min(m.end, len(mapping) - 1)], m.replacement
            )
            for m in matches
        ]

    def _resolve_overlaps(self, matches: list[_Match]) -> list[_Match]:
        # Greedy selection: earliest start wins, ties broken toward the
        # longest match (negative length in the sort key); any match
        # overlapping an accepted one is dropped.
        if not matches:
            return []
        sorted_m = sorted(matches, key=lambda m: (m.start, -(m.end - m.start)))
        result: list[_Match] = []
        last_end = -1
        for m in sorted_m:
            if m.start >= last_end:
                result.append(m)
                last_end = m.end
        return result

    def correct(self, text: str) -> str:
        """Apply pattern corrections.

        Scans ``text`` as UTF-8 bytes, converts hits to character
        offsets, resolves overlaps, then splices replacements in from
        rightmost to leftmost so earlier offsets remain valid.
        """
        if not self._db or not text:
            return text
        matches = self._scan(text.encode("utf-8"))
        if not matches:
            return text
        matches = self._byte_to_char(text, matches)
        matches = self._resolve_overlaps(matches)
        for m in sorted(matches, key=lambda x: -x.start):
            text = text[: m.start] + m.replacement + text[m.end :]
        return text
class Replacer:
    """Direct string replacement.

    Args:
        replacements: Ordered (search, replace) pairs; each search string
            is matched literally and every occurrence is replaced.
    """

    def __init__(self, replacements: Sequence[tuple[str, str]]) -> None:
        # Store the literal pairs as-is. Using str.replace (rather than
        # re.sub on an escaped pattern) fixes a latent bug: re.sub treats
        # the replacement as a template, so a backslash in it (e.g. a
        # Windows path "C:\\temp") raised re.error "bad escape" or was
        # silently converted to an escape character.
        self._pairs: list[tuple[str, str]] = list(replacements)

    def correct(self, text: str) -> str:
        """Apply all replacements, in the order given."""
        for search, replace in self._pairs:
            text = text.replace(search, replace)
        return text