"""
Word Alignment Module
=====================
BERT tabanlı kelime hizalama (simalign kullanarak)
"""

import sys
from typing import List, Dict, Tuple, Optional
from tqdm import tqdm

# Config import
sys.path.insert(0, '..')
try:
    import config
except ImportError:
    from .. import config


class WordAligner:
    """
    BERT-based word aligner.

    Thin wrapper around the ``simalign`` library's ``SentenceAligner``. The
    underlying model is loaded lazily on first use (see ``initialize``).
    """

    def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None):
        """
        Args:
            model_name: Model name handed to simalign
                (default: ``config.ALIGNMENT_MODEL``).
            device: ``'cpu'`` or ``'cuda'`` (default: ``'cpu'`` when
                ``config.CPU_ONLY`` is truthy, otherwise ``'cuda'``).
        """
        self.model_name = model_name or config.ALIGNMENT_MODEL
        self.device = device or ('cpu' if config.CPU_ONLY else 'cuda')
        self.aligner = None
        self._initialized = False

    def initialize(self):
        """
        Load the simalign model (lazy loading; a no-op once loaded).

        If a CUDA device was requested but torch reports CUDA as unavailable,
        falls back to CPU so the aligner still works on machines without a GPU.

        Raises:
            ImportError: if ``simalign`` is not installed.
            Exception: any error raised while constructing the model is
                printed and re-raised.
        """
        if self._initialized:
            return

        try:
            from simalign import SentenceAligner

            # torch is a simalign dependency, so it is importable here.
            if str(self.device).startswith('cuda'):
                import torch
                if not torch.cuda.is_available():
                    print("⚠️ CUDA not available, falling back to CPU")
                    self.device = 'cpu'

            print(f"🔄 Loading alignment model: {self.model_name}")
            print(f"   Device: {self.device}")

            self.aligner = SentenceAligner(
                model=self.model_name,
                token_type="bpe",
                # m = mwmf, a = inter, i = itermax
                matching_methods="mai",
                # Pass the configured device explicitly; without it simalign
                # uses its own default and the setting above has no effect.
                device=self.device,
            )

            self._initialized = True
            print("✅ Alignment model loaded")

        except ImportError:
            print("❌ simalign not installed. Run: pip install simalign")
            raise
        except Exception as e:
            print(f"❌ Failed to load alignment model: {e}")
            raise

    def align_sentence_pair(self, source: str, target: str) -> List[Dict]:
        """
        Word-align a single sentence pair.

        Both sentences are tokenized on whitespace; the indices in the result
        refer to those whitespace tokens.

        Args:
            source: English sentence.
            target: Turkish sentence.

        Returns:
            List of alignment dicts, e.g.
            ``[{'src_idx': 0, 'tgt_idx': 1, 'src_word': 'pump',
            'tgt_word': 'pompa', 'score': 1.0}, ...]``. simalign exposes no
            per-link scores, so ``score`` is always ``1.0``. Returns an empty
            list for empty input, or when alignment fails (the error is
            printed rather than raised, so batch runs keep going).
        """
        if not self._initialized:
            self.initialize()

        if not source or not target:
            return []

        try:
            src_tokens = source.strip().split()
            tgt_tokens = target.strip().split()

            if not src_tokens or not tgt_tokens:
                return []

            alignments = self.aligner.get_word_aligns(src_tokens, tgt_tokens)

            # simalign returns one list of (src_idx, tgt_idx) pairs per
            # matching method; prefer 'itermax', falling back to 'inter'.
            align_pairs = alignments.get('itermax', alignments.get('inter', []))

            return [
                {
                    'src_idx': src_idx,
                    'tgt_idx': tgt_idx,
                    'src_word': src_tokens[src_idx],
                    'tgt_word': tgt_tokens[tgt_idx],
                    'score': 1.0,  # simalign provides no per-link scores
                }
                for src_idx, tgt_idx in align_pairs
                # Defensive bounds check against the whitespace tokens.
                if src_idx < len(src_tokens) and tgt_idx < len(tgt_tokens)
            ]

        except Exception as e:
            print(f"⚠️ Alignment error: {e}")
            return []

    def align_batch(self, pairs: List[Tuple[str, str]],
                    batch_size: Optional[int] = None,
                    show_progress: bool = True) -> List[List[Dict]]:
        """
        Align a sequence of sentence pairs.

        Args:
            pairs: List of (source, target) tuples.
            batch_size: Accepted for API compatibility but currently unused;
                pairs are aligned one at a time via ``align_sentence_pair``.
            show_progress: Show a tqdm progress bar.

        Returns:
            One alignment list (as returned by ``align_sentence_pair``) per
            input pair, in input order.
        """
        if not self._initialized:
            self.initialize()

        iterator = tqdm(pairs, desc="Aligning") if show_progress else pairs
        return [self.align_sentence_pair(source, target) for source, target in iterator]

    def extract_aligned_terms(self, source: str, target: str,
                              min_score: Optional[float] = None) -> List[Tuple[str, str, float]]:
        """
        Extract aligned single-word term pairs.

        Args:
            source: English sentence.
            target: Turkish sentence.
            min_score: Minimum score to keep a pair
                (default: ``config.MIN_CONFIDENCE``).

        Returns:
            List of (source_term, target_term, score) tuples.
        """
        # `is None` rather than `or`, so an explicit 0.0 threshold is honoured.
        if min_score is None:
            min_score = config.MIN_CONFIDENCE

        return [
            (align['src_word'], align['tgt_word'], align['score'])
            for align in self.align_sentence_pair(source, target)
            if align['score'] >= min_score
        ]

    def extract_phrase_alignments(self, source: str, target: str,
                                  max_phrase_len: Optional[int] = None) -> List[Tuple[str, str, float]]:
        """
        Align multi-word phrases by grouping consecutive word alignments.

        Args:
            source: English sentence.
            target: Turkish sentence.
            max_phrase_len: Maximum number of links in a phrase; longer groups
                are discarded (default: ``config.MAX_TERM_WORDS``).

        Returns:
            List of (source_phrase, target_phrase, avg_score) tuples. Each
            target word appears once, in target-sentence order.
        """
        if max_phrase_len is None:
            max_phrase_len = config.MAX_TERM_WORDS

        return self._group_phrases(self.align_sentence_pair(source, target),
                                   max_phrase_len)

    @staticmethod
    def _group_phrases(alignments: List[Dict],
                       max_phrase_len: int) -> List[Tuple[str, str, float]]:
        """
        Merge word-alignment links into phrase pairs.

        A link joins the current group when its source index directly follows
        the previous link's source index and its target index is within one
        position of the previous link's target index. Groups with more than
        ``max_phrase_len`` links are dropped.
        """
        groups: List[List[Dict]] = []
        for link in sorted(alignments, key=lambda a: (a['src_idx'], a['tgt_idx'])):
            prev = groups[-1][-1] if groups else None
            if (prev is not None
                    and link['src_idx'] == prev['src_idx'] + 1
                    and abs(link['tgt_idx'] - prev['tgt_idx']) <= 1):
                groups[-1].append(link)
            else:
                groups.append([link])

        phrases = []
        for group in groups:
            if len(group) > max_phrase_len:
                continue
            # Each step moves the target index by at most one, so the distinct
            # target indices form a contiguous span. Emit that span once, in
            # target order. This avoids repeated words (many-to-one links) and
            # reversed words (reordered links).
            tgt_by_idx = {link['tgt_idx']: link['tgt_word'] for link in group}
            phrases.append((
                ' '.join(link['src_word'] for link in group),
                ' '.join(tgt_by_idx[i] for i in sorted(tgt_by_idx)),
                sum(link['score'] for link in group) / len(group),
            ))
        return phrases


# Test
if __name__ == "__main__":
    demo_aligner = WordAligner()

    # Sample English -> Turkish sentence pair
    en_sentence = "Check the hydraulic pump pressure."
    tr_sentence = "Hidrolik pompa basıncını kontrol edin."

    print("\nSource: " + en_sentence)
    print("Target: " + tr_sentence)

    # Word-level links; the first call also loads the model lazily
    word_links = demo_aligner.align_sentence_pair(en_sentence, tr_sentence)
    print("\nWord Alignments:")
    for link in word_links:
        print("  {src_word} -> {tgt_word}".format(**link))

    # Phrase-level links built from consecutive word links
    phrase_links = demo_aligner.extract_phrase_alignments(en_sentence, tr_sentence)
    print("\nPhrase Alignments:")
    for en_phrase, tr_phrase, avg_score in phrase_links:
        print("  {} -> {} (score: {:.2f})".format(en_phrase, tr_phrase, avg_score))

