import logging
from functools import lru_cache

import chromadb
import google.generativeai as genai
from django.conf import settings
from django.shortcuts import get_object_or_404
from sentence_transformers import SentenceTransformer

from .models import File, Chatbot

logger = logging.getLogger(__name__)

@lru_cache(maxsize=1)
def _get_embedder():
    """Load and cache the sentence-embedding model once per process.

    Constructing SentenceTransformer loads model weights from disk, which is
    far too expensive to repeat on every request.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')


def retrieve_rag_entries(chatbot_id, user, query, n_results=3):
    """
    Retrieve RAG entries from ChromaDB for a given chatbot and query.

    Args:
        chatbot_id (int): ID of the chatbot.
        user (User): Authenticated user making the request; must own the chatbot.
        query (str): User input query.
        n_results (int): Number of RAG results to retrieve (default: 3).

    Returns:
        dict: On success, {'results': [...], 'warnings': [...]} where each
              result has 'file_name', 'text', and 'distance' (lower distance
              means more similar), sorted most-similar first. On failure,
              {'error': <message>} — no 'results'/'warnings' keys.
    """
    try:
        # Validate chatbot ownership. NOTE(review): get_object_or_404 raises
        # Http404, which the broad except at the bottom converts into an
        # {'error': ...} dict instead of letting Django render a 404 —
        # confirm this is the intended behavior for callers.
        chatbot = get_object_or_404(Chatbot, id=chatbot_id, user=user)

        # Per-chatbot collection inside the media-root Chroma store.
        chroma_client = chromadb.PersistentClient(path=str(settings.MEDIA_ROOT / 'chroma_db'))
        collection_name = f"chatbot_{chatbot.id}"

        # Normalize whitespace and embed the query. The embedder is cached
        # per-process (see _get_embedder) rather than reloaded every call.
        query = ' '.join(query.split())
        query_embedding = _get_embedder().encode([query])[0]

        rag_results = []
        missing_files = set()
        try:
            collection = chroma_client.get_collection(name=collection_name)
            query_results = collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                include=['metadatas', 'distances']
            )
            logger.debug(f"Chroma query results: {query_results}")

            # Pair each metadata dict with its distance and sort ascending so
            # the closest (most similar) chunks come first.
            result_pairs = sorted(
                zip(query_results['metadatas'][0], query_results['distances'][0]),
                key=lambda pair: pair[1]
            )

            for metadata, distance in result_pairs[:n_results]:
                file_id = metadata.get('file_id')
                try:
                    file = File.objects.get(id=file_id)
                except File.DoesNotExist:
                    # The vector store still references a file that was removed
                    # from the DB; report it as a warning instead of failing.
                    logger.warning(f"File with ID {file_id} not found")
                    missing_files.add(file_id)
                    continue
                rag_results.append({
                    'file_name': file.file.name,
                    'text': metadata.get('text'),
                    'distance': distance
                })
        except Exception as e:
            # Typically the collection does not exist yet because no files
            # were ever indexed for this chatbot.
            logger.exception(f"Failed to query collection {collection_name}: {str(e)}")
            return {'error': f"No RAG data available: {str(e)}"}

        return {
            'results': rag_results,
            'warnings': [f"File with ID {fid} not found" for fid in missing_files]
        }

    except Exception as e:
        logger.exception(f"RAG retrieval failed: {str(e)}")
        return {'error': str(e)}

def retrieve_ai_response(chatbot_id, user, message, n_results=3, include_ai=True):
    """
    Retrieve RAG entries and optionally generate an AI response.

    Args:
        chatbot_id (int): ID of the chatbot.
        user (User): Authenticated user making the request.
        message (str): User input message/query.
        n_results (int): Number of RAG results to retrieve (default: 3).
        include_ai (bool): Whether to query the AI model for a response (default: True).

    Returns:
        dict: On success, {'results': [...], 'warnings': [...]} plus
              'response' (the AI text) when include_ai is True. On failure,
              {'error': <message>} with no other keys.
    """
    # Retrieve RAG entries; propagate any retrieval error dict unchanged.
    rag_data = retrieve_rag_entries(chatbot_id, user, message, n_results)
    if 'error' in rag_data:
        return rag_data

    response_data = {
        'results': rag_data['results'],
        'warnings': rag_data['warnings']
    }

    # Query AI model if requested
    if include_ai:
        try:
            # Build the prompt: tone/behavior guidelines, then the retrieved
            # document excerpts, then the user's message.
            # (Fixed typo: "conscise" -> "concise".)
            parts = ["""General guidelines:
             use the inclusive, inspiring, global local blend for tone, 
               be concise. ask to rewrite the question if it is not obvious in answer. if you detect that user message is in another language, answer in that language.\n"""]
            parts.append("Relevant document excerpts:\n")
            for idx, result in enumerate(rag_data['results'], 1):
                parts.append(
                    f"{idx}. From {result['file_name']} (Similarity: {result['distance']:.4f}):\n{result['text']}\n\n"
                )
            parts.append(f"User message: {message}")
            context = ''.join(parts)

            # Configure and query the Gemini model with the assembled prompt.
            genai.configure(api_key=settings.GEMINI_API_KEY)
            model = genai.GenerativeModel('gemini-1.5-flash')
            response = model.generate_content(context)
            response_data['response'] = response.text
            logger.debug(f"AI response: {response_data['response']}")
        except Exception as e:
            logger.exception(f"Failed to query AI model: {str(e)}")
            return {'error': f"Failed to get response from AI: {str(e)}"}

    return response_data