"""
Vector-based validation for glossary entries
- Connect to Qdrant (10.10.10.25:6333)
- Generate embeddings for terms
- Search corpus for validation
- Calculate confidence scores
"""

import json
from typing import List, Dict, Tuple, Optional
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

from .utils import (
    QdrantHelper,
    OpenAIHelper,
    is_valid_glossary_entry,
    detect_language,
    logger
)


class GlossaryValidator:
    """Validate glossary entries against a document corpus via vector search.

    Each entry (a dict with 'source' and 'target' keys) is scored on:
      * format/language rules (``is_valid_glossary_entry`` / ``detect_language``),
      * semantic similarity of the source term against a Qdrant collection,
      * whether the source term literally appears in any retrieved corpus text.

    Entries whose weighted confidence reaches ``confidence_threshold`` are
    flagged ``is_valid``. Statistics are kept per validation run.
    """

    def __init__(
        self,
        qdrant_host: str = "10.10.10.25",
        qdrant_port: int = 6333,
        collection: str = "machine_docs",
        confidence_threshold: float = 0.6
    ):
        self.qdrant = QdrantHelper(qdrant_host, qdrant_port, collection)
        self.openai = OpenAIHelper()
        self.confidence_threshold = confidence_threshold
        self.validated_entries: List[Dict] = []
        self.stats: Dict = self._fresh_stats()

    @staticmethod
    def _fresh_stats() -> Dict:
        """Return a zeroed statistics dict (one per validation run)."""
        return {
            'total_validated': 0,
            'high_confidence': 0,
            'low_confidence': 0,
            'corpus_hits': 0,
            'average_confidence': 0.0
        }

    def connect(self) -> bool:
        """Connect to Qdrant and OpenAI; True only if both connections succeed."""
        qdrant_ok = self.qdrant.connect()
        openai_ok = self.openai.connect()
        return qdrant_ok and openai_ok

    def calculate_confidence(
        self,
        source: str,
        target: str,
        search_results: List[Dict]
    ) -> Tuple[float, Dict]:
        """
        Calculate a confidence score in [0, 1] for one entry.

        Args:
            source: Source-language term.
            target: Target-language term (expected to be Turkish).
            search_results: Qdrant hits for the source term's embedding;
                each a dict with a 'score' and a 'payload'.

        Returns:
            (confidence_score, details) where details records the component
            scores and the reasons for any penalty.
        """
        details = {
            'corpus_hit': False,
            'semantic_score': 0.0,
            'format_score': 1.0,
            'reasons': []
        }

        # 1. Format/language validation
        is_valid, reason = is_valid_glossary_entry(source, target)
        if not is_valid:
            details['format_score'] = 0.0
            details['reasons'].append(f"format_invalid: {reason}")

        # A non-Turkish target is penalized (halved), not rejected outright.
        target_lang = detect_language(target)
        if target_lang != 'turkish':
            details['format_score'] *= 0.5
            details['reasons'].append(f"target_lang: {target_lang}")

        # 2. Corpus hit check: best similarity score, plus a literal
        #    substring match of the source term in any hit's text/content.
        if search_results:
            details['semantic_score'] = max(r['score'] for r in search_results)

            source_lower = source.lower()
            for result in search_results:
                payload = result.get('payload', {})
                text = payload.get('text', '') or payload.get('content', '')
                if source_lower in text.lower():
                    details['corpus_hit'] = True
                    details['reasons'].append('corpus_match')
                    break

        # 3. Weighted average: format(0.3) + semantic(0.4) + corpus_hit(0.3).
        #    Entries without a literal corpus match still receive half the
        #    corpus weight rather than zero.
        corpus_bonus = 1.0 if details['corpus_hit'] else 0.5

        confidence = (
            details['format_score'] * 0.3 +
            details['semantic_score'] * 0.4 +
            corpus_bonus * 0.3
        )

        # Clamp to [0, 1] in case a similarity score exceeds 1.
        confidence = max(0.0, min(1.0, confidence))

        return confidence, details

    def _build_result(self, entry: Dict, confidence: float, details: Dict) -> Dict:
        """Return a copy of *entry* annotated with validation fields."""
        return {
            **entry,
            'confidence': confidence,
            'validation_details': details,
            'is_valid': confidence >= self.confidence_threshold
        }

    def _record(self, validated_entry: Dict) -> None:
        """Fold one validated entry into the running statistics."""
        if validated_entry['is_valid']:
            self.stats['high_confidence'] += 1
        else:
            self.stats['low_confidence'] += 1
        if validated_entry['validation_details'].get('corpus_hit'):
            self.stats['corpus_hits'] += 1

    def validate_entry(self, entry: Dict) -> Dict:
        """Validate a single entry. Does not update ``self.stats``."""
        source = entry['source']
        target = entry['target']

        # Embed the source term; skip the search when embedding failed.
        embedding = self.openai.get_embedding(source)
        search_results = self.qdrant.search(embedding, limit=5) if embedding else []

        confidence, details = self.calculate_confidence(source, target, search_results)
        return self._build_result(entry, confidence, details)

    def validate_batch(
        self,
        entries: List[Dict],
        batch_size: int = 50,
        max_workers: int = 4
    ) -> List[Dict]:
        """
        Validate entries using batched embeddings plus vector search.

        Args:
            entries: Dicts with 'source' and 'target' keys.
            batch_size: Terms per embedding request (manages API rate limits).
            max_workers: Accepted for backward compatibility but currently
                unused; searches run sequentially.
                NOTE(review): the ThreadPoolExecutor import at module top
                suggests parallel search was planned — confirm before removing.

        Returns:
            New list of entries annotated with 'confidence',
            'validation_details' and 'is_valid'.
        """
        logger.info(f"Starting validation of {len(entries)} entries...")

        # Reset stats so repeated runs don't mix counters from earlier calls
        # with this run's total/average (previously high/low counts kept
        # accumulating while total_validated was overwritten per call).
        self.stats = self._fresh_stats()

        validated: List[Dict] = []
        total_confidence = 0.0

        # Process in batches to manage API rate limits.
        for i in tqdm(range(0, len(entries), batch_size), desc="Validating"):
            batch = entries[i:i + batch_size]

            # One embeddings request per batch.
            sources = [e['source'] for e in batch]
            embeddings = self.openai.get_embeddings_batch(sources)

            for j, entry in enumerate(batch):
                # The helper may return fewer embeddings than requested.
                embedding = embeddings[j] if j < len(embeddings) else None
                search_results = self.qdrant.search(embedding, limit=5) if embedding else []

                confidence, details = self.calculate_confidence(
                    entry['source'],
                    entry['target'],
                    search_results
                )

                validated_entry = self._build_result(entry, confidence, details)
                validated.append(validated_entry)
                total_confidence += confidence
                self._record(validated_entry)

        self.validated_entries = validated
        self.stats['total_validated'] = len(validated)
        self.stats['average_confidence'] = total_confidence / len(validated) if validated else 0

        logger.info(f"Validation complete: {self.stats['high_confidence']} high confidence, {self.stats['low_confidence']} low confidence")
        return validated

    def validate_simple(self, entries: List[Dict]) -> List[Dict]:
        """
        Simple validation without vector search (rule-based only).

        Use this when Qdrant/OpenAI is not available. Confidence is 0.8 for
        entries passing the format check, 0.3 otherwise, halved when the
        target language is not Turkish.
        """
        logger.info(f"Running simple validation on {len(entries)} entries...")

        # Reset stats for the same reason as validate_batch.
        self.stats = self._fresh_stats()

        validated: List[Dict] = []

        for entry in tqdm(entries, desc="Simple validation"):
            source = entry['source']
            target = entry['target']

            # Rule-based validation
            is_valid, reason = is_valid_glossary_entry(source, target)
            target_lang = detect_language(target)

            # Simple confidence heuristic.
            confidence = 0.8 if is_valid else 0.3
            if target_lang != 'turkish':
                confidence *= 0.5

            details = {
                'corpus_hit': False,
                'semantic_score': 0.0,
                'format_score': 1.0 if is_valid else 0.0,
                'reasons': [reason] if not is_valid else []
            }
            if target_lang != 'turkish':
                details['reasons'].append(f"target_lang: {target_lang}")

            validated_entry = self._build_result(entry, confidence, details)
            validated.append(validated_entry)
            self._record(validated_entry)

        self.validated_entries = validated
        self.stats['total_validated'] = len(validated)
        self.stats['average_confidence'] = sum(e['confidence'] for e in validated) / len(validated) if validated else 0

        return validated

    def get_stats(self) -> Dict:
        """Return statistics for the most recent validation run."""
        return self.stats

    def save_results(self, filepath: str):
        """Save validated entries to *filepath* as UTF-8 JSON."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.validated_entries, f, ensure_ascii=False, indent=2)
        logger.info(f"Saved {len(self.validated_entries)} validated entries to {filepath}")


def run_validation(
    input_file: str,
    output_file: str,
    qdrant_host: str = "10.10.10.25",
    qdrant_port: int = 6333,
    collection: str = "machine_docs",
    confidence_threshold: float = 0.6,
    use_vector: bool = True
) -> Tuple[List[Dict], Dict]:
    """
    Run the full validation pipeline on a file of normalized entries.

    Loads entries from *input_file*, validates them (vector-based when
    requested and available, otherwise rule-based), writes the annotated
    entries to *output_file* and returns (validated_entries, stats).
    """
    with open(input_file, 'r', encoding='utf-8') as handle:
        entries = json.load(handle)

    validator = GlossaryValidator(
        qdrant_host=qdrant_host,
        qdrant_port=qdrant_port,
        collection=collection,
        confidence_threshold=confidence_threshold
    )

    # Only attempt the vector path when asked to; short-circuit keeps
    # connect() from being called at all when use_vector is False.
    vector_ready = use_vector and validator.connect()
    if use_vector and not vector_ready:
        logger.warning("Vector services unavailable, falling back to simple validation")

    if vector_ready:
        validated = validator.validate_batch(entries)
    else:
        validated = validator.validate_simple(entries)

    validator.save_results(output_file)

    return validated, validator.get_stats()


if __name__ == "__main__":
    import sys
    # CLI entry point: <input_file> <output_file>.
    if len(sys.argv) >= 3:
        entries, stats = run_validation(sys.argv[1], sys.argv[2])
        print(f"Stats: {stats}")
    else:
        # Previously exited silently with status 0 when arguments were
        # missing; report usage and a nonzero status instead.
        print(f"Usage: {sys.argv[0]} <input_file> <output_file>", file=sys.stderr)
        sys.exit(1)

