#!/usr/bin/env python3
"""
Terminology Extractor - High Performance Multiprocessing Version
================================================================
BERT tabanlı kelime hizalama ile paralel cümlelerden terim çıkarma.
32 çekirdekli Hetzner AX102 için optimize edilmiş.

Kullanım:
    python main.py                  # Normal çalıştırma (30 worker)
    python main.py --debug          # Debug modu (100 satır, 4 worker)
    python main.py --workers 16     # Özel worker sayısı
    python main.py --input file.csv # Özel girdi dosyası

Çıktılar:
    output/glossary_candidates.csv  # Terim adayları (frekans + skor)
    output/deepl_glossary.csv       # DeepL API formatı
    output/google_automl.tsv        # Google Cloud formatı
"""

import os
import sys
import time
import argparse
import multiprocessing as mp
from multiprocessing import Pool, Manager
from typing import List, Dict, Tuple, Generator, Optional
from datetime import datetime
from collections import defaultdict
from functools import partial

import pandas as pd
from tqdm import tqdm

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import config

# ============================================================
# GLOBAL VARIABLES (used inside the worker processes)
# ============================================================
# Aligner instance; None until init_worker() assigns a WordAligner.
_worker_aligner = None
# Extractor instance; None until init_worker() assigns a TermExtractor.
_worker_extractor = None
# Set to True at the end of init_worker() so later calls return immediately.
_worker_initialized = False


def init_worker():
    """
    Load the models used by ``process_batch`` into module globals.

    Used as the ``Pool`` initializer and also called lazily from
    ``process_batch``. The ``_worker_initialized`` flag turns every call
    after the first successful one into a no-op, so the aligner and the
    extractor are built at most once per process.

    Side effects:
        Assigns ``_worker_aligner``, ``_worker_extractor`` and
        ``_worker_initialized``. When ``torch`` can be imported, calls
        ``torch.set_num_threads(config.TORCH_THREADS)``.
    """
    global _worker_aligner, _worker_extractor, _worker_initialized

    if _worker_initialized:
        return

    # Cap PyTorch's thread count (original note: avoids contention between
    # workers). A missing torch installation is tolerated here.
    try:
        import torch
        torch.set_num_threads(config.TORCH_THREADS)
    except ImportError:
        pass

    # Imported here rather than at module level, so the import cost is paid
    # only by processes that actually initialise the models.
    from src.alignment import WordAligner
    from src.term_extractor import TermExtractor

    _worker_aligner = WordAligner()
    _worker_aligner.initialize()

    _worker_extractor = TermExtractor()
    _worker_extractor.initialize()

    # Flip the flag last: if anything above raises, a later call retries.
    _worker_initialized = True


def process_batch(batch: List[Tuple[str, str]]) -> List[Tuple[str, str, float]]:
    """
    Extract term candidates from one batch of sentence pairs.

    Submitted to the process pool by ``main``. It reuses the module-level
    aligner and extractor, calling ``init_worker`` first if they have not
    been initialised in this process yet.

    Pairs with an empty source or target are skipped. A pair whose
    processing raises is skipped too, so one bad sentence cannot abort the
    batch. Failures are counted and reported once per batch through the
    ``logging`` module instead of being discarded silently.

    Args:
        batch: List of (source, target) sentence tuples.

    Returns:
        List of (source_term, target_term, score) tuples. Word-level
        entries come from the POS-filtered alignments (score defaults to
        1.0 when the alignment has no ``'score'`` key), followed by the
        phrase-level entries for the same pair.
    """
    import logging

    # Lazy initialisation in case the pool initializer did not run.
    if not _worker_initialized:
        init_worker()

    terms = []
    failed = 0
    last_error = None

    for source, target in batch:
        if not source or not target:
            continue

        try:
            # Word-level alignments for the pair.
            alignments = _worker_aligner.align_sentence_pair(source, target)

            # Part-of-speech filter (original note: nouns and adjectives).
            filtered = _worker_extractor.filter_by_pos(alignments, source, target)

            # Phrase-level alignments for the pair.
            phrases = _worker_aligner.extract_phrase_alignments(source, target)

            for align in filtered:
                terms.append((
                    align['src_word'],
                    align['tgt_word'],
                    align.get('score', 1.0)
                ))

            for src_phrase, tgt_phrase, score in phrases:
                terms.append((src_phrase, tgt_phrase, score))

        except Exception as exc:
            # Best effort: keep going, but remember that something failed.
            failed += 1
            last_error = exc

    if failed:
        # One line per batch keeps the output readable while still
        # surfacing systematic problems (e.g. a model that failed to load).
        logging.getLogger(__name__).warning(
            "process_batch: %d of %d sentence pairs failed (last error: %r)",
            failed, len(batch), last_error
        )

    return terms


def parse_args():
    """
    Build the command-line parser and parse the process arguments.

    Options (defaults are read from ``config`` where shown):
        --input / -i        input CSV path (``config.INPUT_FILE``)
        --output / -o       output directory (``config.OUTPUT_DIR``)
        --debug / -d        flag, off by default
        --limit / -l        int (``config.LIMIT``)
        --workers / -w      int, ``None`` when not given
        --chunk-size / -c   int (``config.CHUNK_SIZE``)
        --batch-size / -b   int (``config.BATCH_SIZE``)
        --single / -s       flag, off by default

    Returns:
        The ``argparse.Namespace`` returned by ``ArgumentParser.parse_args``.
    """
    cli = argparse.ArgumentParser(
        description='BERT tabanlı kelime hizalama ile terim çıkarma (Multiprocessing)'
    )

    # (flags, add_argument keyword options), registered in this order.
    option_table = [
        (('--input', '-i'), {
            'type': str,
            'default': config.INPUT_FILE,
            'help': 'Girdi CSV dosyası (source_text, target_text sütunları)',
        }),
        (('--output', '-o'), {
            'type': str,
            'default': config.OUTPUT_DIR,
            'help': 'Çıktı dizini',
        }),
        (('--debug', '-d'), {
            'action': 'store_true',
            'help': 'Debug modu (sadece LIMIT kadar satır, 4 worker)',
        }),
        (('--limit', '-l'), {
            'type': int,
            'default': config.LIMIT,
            'help': 'Debug modda maksimum satır sayısı',
        }),
        (('--workers', '-w'), {
            'type': int,
            'default': None,
            'help': 'Worker sayısı (default: config\'den)',
        }),
        (('--chunk-size', '-c'), {
            'type': int,
            'default': config.CHUNK_SIZE,
            'help': 'CSV okuma chunk boyutu',
        }),
        (('--batch-size', '-b'), {
            'type': int,
            'default': config.BATCH_SIZE,
            'help': 'Her worker\'a gönderilecek satır sayısı',
        }),
        (('--single', '-s'), {
            'action': 'store_true',
            'help': 'Tek process modu (debug için)',
        }),
    ]

    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    return cli.parse_args()


def read_csv_chunks(filepath: str, chunk_size: int,
                    limit: Optional[int] = None) -> Generator[pd.DataFrame, None, None]:
    """
    Read a CSV file lazily, yielding DataFrames of at most ``chunk_size`` rows.

    The source and target columns are detected from the header and renamed
    to ``source_text`` / ``target_text`` in every yielded chunk. A column
    counts as the source when its lower-cased name contains ``source`` or
    equals ``en`` / ``english``. Otherwise it counts as the target when the
    name contains ``target`` or equals ``tr`` / ``turkish``. If several
    columns match, the last one wins. When either role stays undetected,
    the first two columns are used instead.

    This is a generator: nothing runs, including the existence check,
    until the first chunk is requested.

    Args:
        filepath: Path of the CSV file.
        chunk_size: Number of rows per chunk.
        limit: Maximum total number of rows to yield. ``None`` or ``0``
            means no limit.

    Yields:
        DataFrame chunks with the renamed columns.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        ValueError: If the positional fallback is needed but the file has
            fewer than two columns.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Input file not found: {filepath}")

    # Read a single row just to get the header.
    sample = pd.read_csv(filepath, nrows=1)

    source_col = None
    target_col = None

    for col in sample.columns:
        col_lower = col.lower()
        if 'source' in col_lower or col_lower == 'en' or col_lower == 'english':
            source_col = col
        elif 'target' in col_lower or col_lower == 'tr' or col_lower == 'turkish':
            target_col = col

    if not source_col or not target_col:
        # Positional fallback; report a clear error instead of an IndexError.
        if len(sample.columns) < 2:
            raise ValueError(
                f"Input file must have at least two columns "
                f"(found {len(sample.columns)}): {filepath}"
            )
        source_col = sample.columns[0]
        target_col = sample.columns[1]

    print(f"📂 Reading: {filepath}")
    print(f"   Source column: {source_col}")
    print(f"   Target column: {target_col}")

    total_rows = 0

    reader = pd.read_csv(filepath, chunksize=chunk_size)
    try:
        for chunk in reader:
            chunk = chunk.rename(columns={
                source_col: 'source_text',
                target_col: 'target_text'
            })

            if limit:
                remaining = limit - total_rows
                if remaining <= 0:
                    break
                # Trim the chunk that crosses the limit.
                chunk = chunk.head(remaining)

            total_rows += len(chunk)
            yield chunk

            # Stop before parsing another chunk once the limit is reached.
            if limit and total_rows >= limit:
                break
    finally:
        # Runs on early exit too (limit reached, consumer stopped, error),
        # so the underlying file handle is always released.
        reader.close()


def chunk_dataframe_to_batches(df: pd.DataFrame, batch_size: int) -> List[List[Tuple[str, str]]]:
    """
    Split a DataFrame into batches of (source, target) string pairs.

    Values from the ``source_text`` and ``target_text`` columns are
    converted to ``str`` and stripped of surrounding whitespace. Missing
    cells (NaN / None) become the empty string rather than the literal
    text ``'nan'`` / ``'None'``, so a plain emptiness check downstream can
    recognise them.

    Args:
        df: DataFrame with ``source_text`` and ``target_text`` columns.
        batch_size: Maximum number of pairs per batch; must be positive.

    Returns:
        List of batches in row order, each a list of (source, target)
        tuples. The last batch may be shorter than ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # A negative step would silently produce no batches at all.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    def _as_clean_text(column: pd.Series) -> pd.Series:
        # Stringify and strip, then blank out cells that were missing.
        text = column.astype(str).str.strip()
        return text.mask(column.isna().to_numpy(), '')

    rows = list(zip(
        _as_clean_text(df['source_text']),
        _as_clean_text(df['target_text'])
    ))

    return [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]


def main():
    """
    Run the extraction pipeline with a multiprocessing pool.

    Steps:
        1. Parse the CLI options and derive the worker count and row limit.
           Debug mode uses 4 workers and applies ``--limit``; ``--single``
           forces one worker.
        2. Check that the input file exists before any worker is started.
        3. Read the CSV in chunks, split each chunk into batches and map
           ``process_batch`` over them with ``imap_unordered``.
        4. Save a checkpoint after every chunk.
        5. Clean the collected terms, export them, write the log and print
           a summary.

    On Ctrl-C the pool is terminated, a checkpoint marked ``interrupted``
    is saved and the terms collected so far are still cleaned and exported.
    Any other error terminates the pool and is re-raised.

    Returns:
        The cleaned terms as returned by ``TermCleaner.clean_batch``.

    Raises:
        FileNotFoundError: If the input file does not exist.
    """
    args = parse_args()

    # Debug mode: fixed small worker count.
    if args.debug:
        config.DEBUG_MODE = True
        num_workers = 4
    else:
        num_workers = args.workers or config.NUM_WORKERS

    limit = args.limit if config.DEBUG_MODE else None

    # Single-process mode overrides everything else.
    if args.single:
        num_workers = 1

    print("=" * 60)
    print("TERMINOLOGY EXTRACTOR - HIGH PERFORMANCE")
    print("=" * 60)
    print(f"Input: {args.input}")
    print(f"Output: {args.output}")
    print(f"Debug Mode: {config.DEBUG_MODE}")
    print(f"Limit: {limit or 'None (full data)'}")
    print(f"Workers: {num_workers}")
    print(f"Chunk Size: {args.chunk_size}")
    print(f"Batch Size: {args.batch_size}")
    print(f"CPU Cores: {os.cpu_count()}")
    print("=" * 60)

    # Fail fast: the CSV reader is lazy, so without this check a bad path
    # would only surface after every worker had started initialising.
    if not os.path.exists(args.input):
        raise FileNotFoundError(f"Input file not found: {args.input}")

    start_time = time.time()

    # Cleaner and exporter are used by the main process only.
    from src.cleaner import TermCleaner
    from src.exporter import GlossaryExporter

    cleaner = TermCleaner()
    exporter = GlossaryExporter(args.output)

    all_terms = []
    processed_rows = 0
    chunk_count = 0

    print(f"\n🚀 Starting {num_workers} worker processes...")

    pool = Pool(
        processes=num_workers,
        initializer=init_worker,
        maxtasksperchild=config.MAX_TASKS_PER_CHILD
    )

    try:
        print("\n📊 Processing data...\n")

        for chunk_df in read_csv_chunks(args.input, args.chunk_size, limit):
            chunk_count += 1
            chunk_size = len(chunk_df)

            # Split the chunk into worker-sized batches.
            batches = chunk_dataframe_to_batches(chunk_df, args.batch_size)

            print(f"--- Chunk {chunk_count}: {chunk_size} rows, {len(batches)} batches ---")

            chunk_terms = []

            with tqdm(total=len(batches), desc=f"Chunk {chunk_count}",
                      unit="batch", ncols=80) as pbar:

                # Result order does not matter, so take batches as they finish.
                for result in pool.imap_unordered(process_batch, batches):
                    chunk_terms.extend(result)
                    pbar.update(1)

            all_terms.extend(chunk_terms)
            processed_rows += chunk_size

            print(f"   ✓ Found {len(chunk_terms)} terms, Total: {len(all_terms)}")

            exporter.save_checkpoint({
                'chunk_count': chunk_count,
                'processed_rows': processed_rows,
                'term_count': len(all_terms)
            })

        # Normal completion: no more tasks will be submitted.
        pool.close()

    except KeyboardInterrupt:
        print("\n\n⚠️ Interrupted! Saving progress...")
        # Stop the workers right away; close() + join() would wait for
        # every batch that is already queued.
        pool.terminate()
        exporter.save_checkpoint({
            'chunk_count': chunk_count,
            'processed_rows': processed_rows,
            'term_count': len(all_terms),
            'interrupted': True
        })

    except BaseException:
        # Do not leave workers running behind a failing main process.
        pool.terminate()
        raise

    finally:
        pool.join()
        print("\n🛑 Worker processes terminated")

    print("\n🧹 Cleaning and filtering terms...")
    cleaned_terms = cleaner.clean_batch(all_terms)

    print(f"   Raw terms: {len(all_terms)}")
    print(f"   Cleaned terms: {len(cleaned_terms)}")

    print("\n📤 Exporting results...")
    export_results = exporter.export_all(cleaned_terms)

    elapsed_time = time.time() - start_time

    # Compute the rates once, guarded against a zero elapsed time, and use
    # the same values for both the log and the console summary.
    if elapsed_time > 0:
        rows_per_second = f"{processed_rows / elapsed_time:.2f}"
        terms_per_second = f"{len(all_terms) / elapsed_time:.2f}"
    else:
        rows_per_second = 'N/A'
        terms_per_second = 'N/A'

    stats = {
        'Input File': args.input,
        'Processed Rows': processed_rows,
        'Chunks Processed': chunk_count,
        'Workers Used': num_workers,
        'Raw Terms Found': len(all_terms),
        'Cleaned Terms': len(cleaned_terms),
        'Processing Time': f"{elapsed_time:.2f} seconds",
        'Rows per Second': rows_per_second,
        'Terms per Second': terms_per_second,
        'Debug Mode': config.DEBUG_MODE,
        'Limit': limit or 'None'
    }

    exporter.write_log(stats)

    print("\n" + "=" * 60)
    print("COMPLETED")
    print("=" * 60)
    print(f"⏱️  Processing time: {elapsed_time:.2f} seconds")
    print(f"👷 Workers used: {num_workers}")
    print(f"📄 Processed rows: {processed_rows}")
    print(f"🔤 Raw terms found: {len(all_terms)}")
    print(f"✅ Cleaned terms: {len(cleaned_terms)}")
    print(f"⚡ Speed: {rows_per_second} rows/sec")
    print(f"\n📁 Output files:")
    for fmt, path in export_results.items():
        print(f"   {fmt}: {path}")

    # Show the first ten cleaned terms.
    if cleaned_terms:
        print(f"\n🏆 Top 10 Terms:")
        for i, term in enumerate(cleaned_terms[:10], 1):
            print(f"   {i}. {term['source']} -> {term['target']} "
                  f"(freq: {term['frequency']}, conf: {term['confidence']:.2f})")

    print("\n" + "=" * 60)

    return cleaned_terms


if __name__ == "__main__":
    # Needed for multiprocessing (Windows compatibility).
    mp.freeze_support()
    main()
