#!/usr/bin/env python3
"""
Qdrant-based Glossary Validation
- Create embeddings for glossary terms
- Search in machine_docs corpus
- Categorize as verified/uncertain/rejected
- Save verified to glossary_terms collection
"""

import os
import json
from typing import List, Dict, Tuple, Optional
from datetime import datetime
from tqdm import tqdm

try:
    from openai import OpenAI
    from qdrant_client import QdrantClient
    from qdrant_client.models import (
        Distance, VectorParams, PointStruct,
        Filter, FieldCondition, MatchText
    )
    DEPS_AVAILABLE = True
except ImportError as e:
    print(f"Warning: {e}")
    DEPS_AVAILABLE = False


class QdrantValidator:
    """Validate glossary terms against a Qdrant corpus.

    For each glossary entry (a dict with a ``source`` term) the validator
    embeds the term with OpenAI, searches ``source_collection`` for the
    nearest corpus chunks, and buckets the entry by its best similarity
    score:

    * ``verified``  -- score >= ``verified_threshold`` (default 0.7)
    * ``uncertain`` -- score >= ``uncertain_threshold`` (default 0.5)
    * ``rejected``  -- lower score, failed/blank embedding, or no search hit

    Verified entries can then be written to ``target_collection``.
    Call ``connect()`` before any method that talks to OpenAI or Qdrant.
    """

    # Class-level defaults so instances created without __init__ still work.
    verified_threshold: float = 0.7
    uncertain_threshold: float = 0.5

    def __init__(
        self,
        qdrant_host: str = "10.10.10.25",
        qdrant_port: int = 6333,
        source_collection: str = "machine_docs",
        target_collection: str = "glossary_terms",
        embedding_model: str = "text-embedding-3-large",
        verified_threshold: float = 0.7,
        uncertain_threshold: float = 0.5,
    ):
        """Store configuration; no network I/O happens until connect().

        Args:
            qdrant_host: Qdrant server host.
            qdrant_port: Qdrant server port.
            source_collection: corpus collection searched for each term.
            target_collection: collection (re)created for verified terms.
            embedding_model: OpenAI embedding model name.
            verified_threshold: best score at/above which a term is verified.
            uncertain_threshold: best score at/above which (but below
                verified_threshold) a term is uncertain; lower is rejected.

        Raises:
            ValueError: if uncertain_threshold > verified_threshold.
        """
        if uncertain_threshold > verified_threshold:
            raise ValueError(
                "uncertain_threshold must not exceed verified_threshold"
            )

        self.qdrant_host = qdrant_host
        self.qdrant_port = qdrant_port
        self.source_collection = source_collection
        self.target_collection = target_collection
        self.embedding_model = embedding_model
        self.verified_threshold = verified_threshold
        self.uncertain_threshold = uncertain_threshold
        # Default output size of text-embedding-3-large. Kept for callers that
        # read it; create_glossary_collection sizes the collection from the
        # embeddings actually produced, so other models also work.
        self.vector_size = 3072

        # Clients are created by connect().
        self.qdrant = None
        self.openai = None

        self.stats = {
            'total': 0,
            'verified': 0,
            'uncertain': 0,
            'rejected': 0,
            'embedding_tokens': 0
        }

    def connect(self) -> bool:
        """Connect to Qdrant and OpenAI.

        Returns:
            True on success. On any failure (missing dependencies, unreachable
            Qdrant, missing source collection, OpenAI client setup) the
            problem is printed and False is returned.
        """
        if not DEPS_AVAILABLE:
            print("❌ Required dependencies not available")
            return False

        try:
            # Qdrant
            self.qdrant = QdrantClient(
                host=self.qdrant_host,
                port=self.qdrant_port,
                timeout=60
            )
            info = self.qdrant.get_collection(self.source_collection)
            # points_count may be None in the collection info model; a missing
            # count must not turn a working connection into a failure.
            count = info.points_count or 0
            print(f"✓ Qdrant connected: {count:,} vectors in {self.source_collection}")

            # OpenAI
            self.openai = OpenAI()
            print(f"✓ OpenAI connected (model: {self.embedding_model})")

            return True
        except Exception as e:
            print(f"❌ Connection error: {e}")
            return False

    def _add_token_usage(self, response, texts: List[str]) -> None:
        """Add one request's token usage to stats['embedding_tokens'].

        Uses the usage reported by the API when present, otherwise falls back
        to a rough estimate of ~2 tokens per whitespace-separated word.
        """
        usage = getattr(response, 'usage', None)
        tokens = getattr(usage, 'total_tokens', None)
        if not isinstance(tokens, int):
            tokens = sum(len(t.split()) * 2 for t in texts)
        self.stats['embedding_tokens'] += tokens

    def get_embedding(self, text: str) -> Optional[List[float]]:
        """Return the embedding for a single text.

        Returns None for blank/non-string input (the API rejects empty
        strings) or when the request fails (the error is printed).
        """
        if not isinstance(text, str) or not text.strip():
            return None
        try:
            response = self.openai.embeddings.create(
                input=text,
                model=self.embedding_model
            )
            embedding = response.data[0].embedding
        except Exception as e:
            print(f"Embedding error: {e}")
            return None
        self._add_token_usage(response, [text])
        return embedding

    def get_embeddings_batch(self, texts: List[str], batch_size: int = 100) -> List[Optional[List[float]]]:
        """Embed many texts with at most ``batch_size`` texts per request.

        Returns:
            A list aligned with ``texts``. Position i holds the embedding of
            texts[i], or None if that text was blank/non-string or its batch
            request failed (the error is printed).
        """
        embeddings: List[Optional[List[float]]] = [None] * len(texts)
        # The embeddings endpoint rejects empty input, and one invalid item
        # fails the whole request, so only non-blank strings are sent.
        pending = [
            (pos, text) for pos, text in enumerate(texts)
            if isinstance(text, str) and text.strip()
        ]
        step = max(1, batch_size)

        for start in range(0, len(pending), step):
            chunk = pending[start:start + step]
            batch = [text for _, text in chunk]
            try:
                response = self.openai.embeddings.create(
                    input=batch,
                    model=self.embedding_model
                )
                # Place each vector by its reported index within this request
                # (not by arrival order), so alignment with `texts` holds.
                for item in response.data:
                    embeddings[chunk[item.index][0]] = item.embedding
            except Exception as e:
                print(f"Batch embedding error: {e}")
                continue
            self._add_token_usage(response, batch)

        return embeddings

    def search_corpus(self, embedding: List[float], limit: int = 5) -> List[Dict]:
        """Search ``source_collection`` for the nearest neighbours of ``embedding``.

        Returns:
            Up to ``limit`` dicts ``{'score': float, 'payload': dict | None}``
            in the order returned by Qdrant, or [] on error (printed).
        """
        try:
            results = self.qdrant.query_points(
                collection_name=self.source_collection,
                query=embedding,
                limit=limit,
                with_payload=True
            )
            return [
                {
                    'score': r.score,
                    'payload': r.payload
                }
                for r in results.points
            ]
        except Exception as e:
            print(f"Search error: {e}")
            return []

    def categorize_result(self, score: float) -> str:
        """Map a similarity score to 'verified', 'uncertain' or 'rejected'."""
        if score >= self.verified_threshold:
            return 'verified'
        if score >= self.uncertain_threshold:
            return 'uncertain'
        return 'rejected'

    @staticmethod
    def _payload_text(payload: Optional[Dict]) -> str:
        """Return a hit's text, preferring 'text' over 'content'; '' if absent."""
        if not payload:
            return ''
        text = payload.get('text') or payload.get('content') or ''
        return text if isinstance(text, str) else str(text)

    def _classify_entry(
        self,
        entry: Dict,
        embedding: Optional[List[float]],
        results: List[Dict],
    ) -> str:
        """Annotate ``entry`` in place from its embedding and search results.

        Sets 'confidence' and 'category' on every entry. It also sets 'reason'
        on rejections, and 'embedding'/'corpus_hit' when there was a search hit.

        Returns:
            The category: 'verified', 'uncertain' or 'rejected'.
        """
        if embedding is None:
            entry['confidence'] = 0
            entry['category'] = 'rejected'
            entry['reason'] = 'embedding_failed'
            return 'rejected'

        if not results:
            entry['confidence'] = 0
            entry['category'] = 'rejected'
            entry['reason'] = 'no_corpus_match'
            return 'rejected'

        # results are taken as best-first (the order query_points returns).
        best = results[0]
        category = self.categorize_result(best['score'])
        entry['confidence'] = best['score']
        entry['category'] = category
        entry['embedding'] = embedding  # kept so verified entries can be saved later

        # Does the term literally occur (case-insensitively) in the best hit?
        source = entry.get('source')
        matched_text = self._payload_text(best.get('payload'))
        entry['corpus_hit'] = bool(source) and str(source).lower() in matched_text.lower()

        if category == 'rejected':
            entry['reason'] = 'low_score'
        return category

    def validate_glossary(self, entries: List[Dict]) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """
        Validate glossary entries against the corpus.

        Entries are annotated in place (see _classify_entry). stats['total']
        and the per-category counts are set from this call's results.

        Returns: (verified, uncertain, rejected)
        """
        print(f"\n{'='*50}")
        print("STEP 4: CREATE EMBEDDINGS")
        print(f"{'='*50}")

        # A missing source becomes '' and is rejected as embedding_failed.
        sources = [e.get('source', '') for e in entries]
        print(f"Creating embeddings for {len(sources)} terms...")

        embeddings = self.get_embeddings_batch(sources)
        print(f"✓ Embeddings created (est. {self.stats['embedding_tokens']} tokens)")

        print(f"\n{'='*50}")
        print("STEP 5: VALIDATE AGAINST CORPUS")
        print(f"{'='*50}")

        buckets: Dict[str, List[Dict]] = {'verified': [], 'uncertain': [], 'rejected': []}
        pairs = zip(entries, embeddings)
        for entry, embedding in tqdm(pairs, total=len(entries), desc="Validating"):
            results = self.search_corpus(embedding, limit=3) if embedding is not None else []
            category = self._classify_entry(entry, embedding, results)
            buckets[category].append(entry)

        verified = buckets['verified']
        uncertain = buckets['uncertain']
        rejected = buckets['rejected']

        # Derive counts from the buckets so every rejection (including failed
        # embeddings and empty searches) is counted, and repeated calls do not
        # accumulate stale totals.
        self.stats.update(
            total=len(entries),
            verified=len(verified),
            uncertain=len(uncertain),
            rejected=len(rejected),
        )

        hi, lo = self.verified_threshold, self.uncertain_threshold
        print(f"\n📊 Validation results:")
        print(f"   Verified (>= {hi}): {len(verified)}")
        print(f"   Uncertain ({lo}-{hi}): {len(uncertain)}")
        print(f"   Rejected (< {lo}): {len(rejected)}")

        return verified, uncertain, rejected

    @staticmethod
    def _point_payload(entry: Dict, verified_at: str) -> Dict:
        """Build the Qdrant payload stored for one verified entry.

        Raises KeyError if 'source', 'target' or 'confidence' is missing.
        """
        return {
            'source': entry['source'],
            'target': entry['target'],
            'target_google': entry.get('target_google', ''),
            'target_deepl': entry.get('target_deepl', ''),
            'type': entry.get('type', 'term'),
            'confidence': entry['confidence'],
            'corpus_hit': entry.get('corpus_hit', False),
            'origin': entry.get('origin', ''),
            'verified_at': verified_at
        }

    def create_glossary_collection(self, verified_entries: List[Dict]) -> bool:
        """(Re)create ``target_collection`` and upsert the verified entries.

        Entries without an embedding are skipped. Points are built before the
        existing collection is touched. Malformed input (e.g. a missing
        'target') therefore aborts without deleting anything, and an empty
        batch leaves the existing collection in place.

        Returns:
            True if the collection was rebuilt, False otherwise (printed).
        """
        print(f"\n{'='*50}")
        print("STEP 6: CREATE GLOSSARY COLLECTION")
        print(f"{'='*50}")

        try:
            # One timestamp for the whole run.
            verified_at = datetime.now().isoformat()
            embedded = [
                (i, entry) for i, entry in enumerate(verified_entries)
                if entry.get('embedding') is not None
            ]
            points = [
                PointStruct(
                    id=i,
                    vector=entry['embedding'],
                    payload=self._point_payload(entry, verified_at)
                )
                for i, entry in embedded
            ]

            if not points:
                print(f"⚠️ No verified entries with embeddings; {self.target_collection} left unchanged")
                return False

            # Size the collection from the real embeddings so it matches
            # whichever embedding_model produced them.
            vector_size = len(embedded[0][1]['embedding'])

            collections = [c.name for c in self.qdrant.get_collections().collections]
            if self.target_collection in collections:
                print(f"⚠️ Collection {self.target_collection} exists, recreating...")
                self.qdrant.delete_collection(self.target_collection)

            self.qdrant.create_collection(
                collection_name=self.target_collection,
                vectors_config=VectorParams(
                    size=vector_size,
                    distance=Distance.COSINE
                )
            )
            print(f"✓ Collection {self.target_collection} created")

            batch_size = 100
            for start in range(0, len(points), batch_size):
                self.qdrant.upsert(
                    collection_name=self.target_collection,
                    points=points[start:start + batch_size]
                )

            print(f"✓ Saved {len(points)} verified entries to {self.target_collection}")
            return True

        except Exception as e:
            print(f"❌ Collection creation error: {e}")
            return False

    def get_stats(self) -> Dict:
        """Return the live statistics dict (not a copy)."""
        return self.stats


def run_qdrant_validation(
    glossary_file: str,
    output_dir: str,
    qdrant_host: str = "10.10.10.25",
    qdrant_port: int = 6333
) -> Tuple[List[Dict], List[Dict], List[Dict], Dict]:
    """
    Load a glossary JSON file, validate it against Qdrant, and save verified terms.

    Args:
        glossary_file: path to a UTF-8 JSON file containing a list of entry
            dicts (entries are read by their 'source' key; verified ones are
            also saved with their 'target').
        output_dir: accepted for interface compatibility; not used here.
        qdrant_host: Qdrant server host.
        qdrant_port: Qdrant server port.

    Returns:
        (verified, uncertain, rejected, stats). If connecting fails, every
        entry is returned in ``rejected``, annotated like other rejections
        (confidence 0, category 'rejected', reason 'connection_failed'), and
        ``stats`` is an empty dict.

    Raises:
        ValueError: if the JSON file does not contain a list.
    """
    with open(glossary_file, 'r', encoding='utf-8') as f:
        entries = json.load(f)

    if not isinstance(entries, list):
        raise ValueError(
            f"{glossary_file}: expected a JSON list of entries, got {type(entries).__name__}"
        )

    print(f"Loaded {len(entries)} glossary entries for validation")

    validator = QdrantValidator(
        qdrant_host=qdrant_host,
        qdrant_port=qdrant_port
    )

    if not validator.connect():
        print("Failed to connect, returning empty results")
        # Give these the same fields as any other rejected entry so callers
        # can read confidence/category/reason uniformly.
        for entry in entries:
            entry['confidence'] = 0
            entry['category'] = 'rejected'
            entry['reason'] = 'connection_failed'
        return [], [], entries, {}

    verified, uncertain, rejected = validator.validate_glossary(entries)

    # Write verified entries to the target collection; it prints its own outcome.
    validator.create_glossary_collection(verified)

    return verified, uncertain, rejected, validator.get_stats()


if __name__ == "__main__":
    # Resolve from the script's absolute path so the data directory
    # (<script dir>/../data) is found regardless of the working directory.
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    data_dir = os.path.join(base_dir, 'data')
    glossary_file = os.path.join(data_dir, 'glossary_for_validation.json')

    verified, uncertain, rejected, stats = run_qdrant_validation(
        glossary_file, data_dir
    )

    print(f"\nFinal stats: {stats}")

