"""
Normalization pipeline for glossary entries
- Trim and whitespace cleaning
- Lowercase (preserving Turkish characters)
- Plural to singular conversion
- OCR/spelling error correction
- Fuzzy matching for variant grouping
"""

import re
import json
from typing import List, Dict, Tuple
from tqdm import tqdm

try:
    from rapidfuzz import fuzz, process
except ImportError:
    fuzz = None
    process = None

from .utils import (
    fix_ocr_errors, 
    normalize_turkish, 
    to_lowercase_turkish,
    is_valid_glossary_entry,
    logger
)


# Turkish plural suffixes
TURKISH_PLURAL_SUFFIXES = [
    'lar', 'ler', 'ları', 'leri', 'lara', 'lere',
    'lardan', 'lerden', 'larda', 'lerde'
]

# English plural rules (applied in order; the first matching pattern wins)
ENGLISH_PLURAL_RULES = [
    (r'ies$', 'y'),      # batteries -> battery
    (r'ves$', 'f'),      # shelves -> shelf
    (r'oes$', 'o'),      # heroes -> hero
    (r'ses$', 's'),      # gases -> gas
    (r'xes$', 'x'),      # boxes -> box
    (r'ches$', 'ch'),    # switches -> switch
    (r'shes$', 'sh'),    # bushes -> bush
    (r's$', ''),         # pumps -> pump
]


def trim_and_clean(text: str) -> str:
    """Trim whitespace and normalize spacing"""
    # Remove leading/trailing whitespace
    text = text.strip()
    # Normalize multiple spaces to single space
    text = re.sub(r'\s+', ' ', text)
    # Normalize punctuation spacing: no space before, exactly one space after ',' and '.'
    text = re.sub(r'\s*([,.])\s*', r'\1 ', text)
    # The rule above can leave a trailing space (e.g. after a final '.'); drop it
    text = re.sub(r'\s+$', '', text)
    return text
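

# Example behaviour of trim_and_clean(), traced from the rules above
# (illustrative input, not from any real glossary):
#   trim_and_clean("  hydraulic   pump ,assembly ")  ->  "hydraulic pump, assembly"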


def to_singular_english(word: str) -> str:
    """Convert English plural to singular"""
    word_lower = word.lower()
    
    # Skip very short words (e.g. 'gas', 'bus') where stripping a final 's' would be wrong
    if len(word_lower) <= 3:
        return word
    
    # Apply rules
    for pattern, replacement in ENGLISH_PLURAL_RULES:
        if re.search(pattern, word_lower):
            result = re.sub(pattern, replacement, word_lower)
            # Preserve original case
            if word.isupper():
                return result.upper()
            elif word[0].isupper():
                return result.capitalize()
            return result
    
    return word
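

# Expected behaviour of the heuristic above (first matching rule wins):
#   to_singular_english("batteries") -> "battery"
#   to_singular_english("Pumps")     -> "Pump"   (original case is preserved)
#   to_singular_english("gas")       -> "gas"    (words of 3 characters or fewer are left alone)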


def to_singular_turkish(word: str) -> str:
    """Convert Turkish plural to singular (basic)"""
    # Turkish-aware lowercasing so dotted/dotless I (İ/ı) is handled correctly
    word_lower = to_lowercase_turkish(word)
    
    for suffix in TURKISH_PLURAL_SUFFIXES:
        if word_lower.endswith(suffix) and len(word_lower) > len(suffix) + 2:
            result = word_lower[:-len(suffix)]
            if word.isupper():
                return result.upper()
            elif word[0].isupper():
                return result.capitalize()
            return result
    
    return word
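

# Expected behaviour of the Turkish heuristic above:
#   to_singular_turkish("pompalar")  -> "pompa"
#   to_singular_turkish("Motorları") -> "Motor"
#   to_singular_turkish("valf")      -> "valf"   (no plural suffix, returned unchanged)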


def normalize_entry(source: str, target: str) -> Tuple[str, str]:
    """
    Normalize a single glossary entry
    Returns: (normalized_source, normalized_target)
    """
    # 1. Trim and clean
    source = trim_and_clean(source)
    target = trim_and_clean(target)
    
    # 2. Fix OCR errors
    source = fix_ocr_errors(source)
    target = fix_ocr_errors(target)
    
    # 3. Normalize Turkish characters
    target = normalize_turkish(target)
    
    return source, target
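

# Note: the plural -> singular helpers above are defined here but not called by
# normalize_entry(). A minimal sketch of how they could be wired in, applying the
# heuristic to the last token of a term only (singularize_term is a hypothetical
# helper, not part of the pipeline as written):
def singularize_term(term: str, language: str = 'en') -> str:
    """Singularize the last word of a (possibly multi-word) term."""
    if not term:
        return term
    words = term.split(' ')
    converter = to_singular_english if language == 'en' else to_singular_turkish
    words[-1] = converter(words[-1])
    return ' '.join(words)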


class GlossaryNormalizer:
    """Main normalizer class for glossary entries"""
    
    def __init__(self, fuzzy_threshold: float = 0.85):
        self.fuzzy_threshold = fuzzy_threshold
        self.entries = []
        self.normalized_entries = []
        self.merged_count = 0
        self.stats = {
            'total_input': 0,
            'after_normalization': 0,
            'duplicates_removed': 0,
            'variants_merged': 0,
            'invalid_removed': 0
        }
    
    def load_tsv(self, filepath: str) -> int:
        """Load glossary from TSV file"""
        self.entries = []
        
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line or '\t' not in line:
                    continue
                
                parts = line.split('\t')
                if len(parts) >= 2:
                    source, target = parts[0], parts[1]
                    self.entries.append({
                        'source': source,
                        'target': target,
                        'original_source': source,
                        'original_target': target
                    })
        
        self.stats['total_input'] = len(self.entries)
        logger.info(f"Loaded {len(self.entries)} entries from {filepath}")
        return len(self.entries)
    
    def normalize_all(self) -> List[Dict]:
        """Apply normalization to all entries"""
        normalized = []
        seen = set()
        invalid_count = 0
        
        for entry in tqdm(self.entries, desc="Normalizing"):
            source, target = normalize_entry(entry['source'], entry['target'])
            
            # Validate entry
            is_valid, reason = is_valid_glossary_entry(source, target)
            
            if not is_valid:
                entry['invalid'] = True
                entry['invalid_reason'] = reason
                invalid_count += 1
                continue
            
            # Check for duplicates
            key = (source.lower(), target.lower())
            if key in seen:
                continue
            seen.add(key)
            
            normalized.append({
                'source': source,
                'target': target,
                'original_source': entry['original_source'],
                'original_target': entry['original_target'],
                'normalized': source != entry['original_source'] or target != entry['original_target']
            })
        
        self.normalized_entries = normalized
        self.stats['after_normalization'] = len(normalized)
        self.stats['duplicates_removed'] = self.stats['total_input'] - len(normalized) - invalid_count
        self.stats['invalid_removed'] = invalid_count
        
        logger.info(f"Normalization complete: {len(normalized)} entries (removed {invalid_count} invalid, {self.stats['duplicates_removed']} duplicates)")
        return normalized
    
    def group_variants(self) -> List[Dict]:
        """Group similar entries using fuzzy matching"""
        if not fuzz or not process:
            logger.warning("rapidfuzz not available, skipping variant grouping")
            return self.normalized_entries
        
        # Group by similar source terms
        sources = [e['source'] for e in self.normalized_entries]
        grouped = []
        used_indices = set()
        
        for i, entry in enumerate(tqdm(self.normalized_entries, desc="Grouping variants")):
            if i in used_indices:
                continue
            
            # Find similar sources
            similar = []
            for j, other_source in enumerate(sources):
                if j != i and j not in used_indices:
                    ratio = fuzz.ratio(entry['source'].lower(), other_source.lower())
                    if ratio >= self.fuzzy_threshold * 100:
                        similar.append((j, ratio))
            
            if similar:
                # Merge similar entries
                variants = [entry]
                for j, _ in similar:
                    variants.append(self.normalized_entries[j])
                    used_indices.add(j)
                
                # Pick the best one (longest source)
                best = max(variants, key=lambda x: len(x['source']))
                best['variants'] = [v['source'] for v in variants if v != best]
                grouped.append(best)
                self.merged_count += len(variants) - 1
            else:
                grouped.append(entry)
            
            used_indices.add(i)
        
        # Keep the grouped list so save_normalized() writes the merged entries
        self.normalized_entries = grouped
        self.stats['variants_merged'] = self.merged_count
        logger.info(f"Variant grouping complete: {len(grouped)} entries ({self.merged_count} merged)")
        return grouped
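
    # The loop in group_variants() is O(n^2) in pure Python. rapidfuzz's
    # process.extract can score one query against all sources in compiled code;
    # the sketch below shows that idea (_find_similar_indices is a hypothetical
    # helper and is not used by group_variants above).
    def _find_similar_indices(self, idx: int, sources_lower: List[str]) -> List[int]:
        """Indices of sources fuzzy-matching sources_lower[idx] at the configured threshold."""
        if not process:
            return []
        matches = process.extract(
            sources_lower[idx],
            sources_lower,
            scorer=fuzz.ratio,
            score_cutoff=self.fuzzy_threshold * 100,
            limit=None,
        )
        # process.extract yields (choice, score, index) tuples; skip the query itself
        return [j for _, _, j in matches if j != idx]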
    
    def get_stats(self) -> Dict:
        """Get normalization statistics"""
        return self.stats
    
    def save_normalized(self, filepath: str):
        """Save normalized entries to JSON"""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.normalized_entries, f, ensure_ascii=False, indent=2)
        logger.info(f"Saved {len(self.normalized_entries)} entries to {filepath}")


def run_normalization(input_file: str, output_file: str, fuzzy_threshold: float = 0.85) -> Tuple[List[Dict], Dict]:
    """
    Main function to run normalization pipeline
    Returns: (normalized_entries, stats)
    """
    normalizer = GlossaryNormalizer(fuzzy_threshold=fuzzy_threshold)
    
    # Load
    normalizer.load_tsv(input_file)
    
    # Normalize
    normalizer.normalize_all()
    
    # Group variants (optional - can be slow)
    # normalizer.group_variants()
    
    # Save
    normalizer.save_normalized(output_file)
    
    return normalizer.normalized_entries, normalizer.get_stats()
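

# Typical programmatic use (file names are illustrative):
#   entries, stats = run_normalization("glossary_raw.tsv", "glossary_normalized.json")
#   print(stats['total_input'], stats['after_normalization'])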


if __name__ == "__main__":
    import sys
    if len(sys.argv) >= 3:
        entries, stats = run_normalization(sys.argv[1], sys.argv[2])
        print(f"Stats: {stats}")
    else:
        print(f"Usage: python {sys.argv[0]} <input.tsv> <output.json>")

