import time
from django.core.management.base import BaseCommand
from django.core.files.storage import default_storage
from chatbot.models import Chatbot, File, RAGMetadata
from PyPDF2 import PdfReader
from docx import Document
from sentence_transformers import SentenceTransformer
import chromadb
import os
from django.conf import settings
import spacy
import logging

# Set up logging
logger = logging.getLogger(__name__)
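
# Typical invocation (the management command name comes from this module's
# file name, e.g. chatbot/management/commands/process_rag.py):
#   python manage.py process_rag <chatbot_id> --chunk_strategy sentence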

class Command(BaseCommand):
    help = 'Processes PDF, DOCX, and TXT files for a chatbot and creates RAG data with chunking'

    def add_arguments(self, parser):
        parser.add_argument('chatbot_id', type=int, help='ID of the chatbot to process')
        parser.add_argument('--chunk_strategy', default='sentence', 
                          choices=['sentence', 'fixed', 'paragraph'], 
                          help='Chunking strategy: sentence (with paragraph awareness), fixed-size, or paragraph-based')

    def extract_text_from_pdf(self, file_path):
        """Extract text from PDF with paragraph detection."""
        try:
            with open(file_path, 'rb') as f:
                pdf = PdfReader(f)
                text = ""
                for page_num, page in enumerate(pdf.pages, 1):
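                    # extract_text() can come back empty for pages with no extractable
                    # text (e.g. scanned images); the `or ""` also guards against None.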
                    page_text = page.extract_text() or ""
                    text += page_text + "\n\n"
                    self.stdout.write(f"Extracted text from page {page_num} ({len(page_text)} characters)")
                return text
        except Exception as e:
            raise Exception(f"Failed to extract text from PDF")

    def extract_text_from_docx(self, file_path):
        """Extract text from DOCX with paragraph detection."""
        try:
            doc = Document(file_path)
            text = ""
            for para in doc.paragraphs:
                if para.text.strip():
                    text += para.text.strip() + "\n\n"
            self.stdout.write(f"Extracted text from DOCX ({len(text)} characters)")
            return text
        except Exception as e:
            raise Exception(f"Failed to extract text from DOCX: {str(e)}")

    def extract_text_from_txt(self, file_path):
        """Extract text from TXT file."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()
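                # Normalize whitespace so paragraphs are consistently separated by
                # blank lines, matching what the chunking strategies expect.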
                text = '\n\n'.join(para.strip() for para in text.split('\n\n') if para.strip())
                self.stdout.write(f"Extracted text from TXT ({len(text)} characters)")
                return text
        except Exception as e:
            raise Exception(f"Failed to extract text from TXT: {str(e)}")

    def extract_text_from_file(self, file_path):
        """Extract text from a file based on its extension."""
        file_extension = os.path.splitext(file_path)[1].lower()
        if file_extension == '.pdf':
            return self.extract_text_from_pdf(file_path)
        elif file_extension == '.docx':
            return self.extract_text_from_docx(file_path)
        elif file_extension == '.txt':
            return self.extract_text_from_txt(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

    def sentence_chunking(self, text, max_tokens=400, overlap_tokens=80):
        """Chunk text into sentences while respecting paragraph boundaries."""
        nlp = spacy.load('en_core_web_sm')
        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
        chunks = []
        chunk_index = 0

        self.stdout.write(f"Detected {len(paragraphs)} paragraphs")

        for para in paragraphs:
            doc = nlp(para)
            current_chunk = []
            current_tokens = 0
            sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]

            for sent in sentences:
                sent_tokens = len(nlp(sent))
                if current_tokens + sent_tokens > max_tokens:
                    if current_chunk:
                        chunks.append({
                            'text': ' '.join(current_chunk),
                            'index': chunk_index,
                            'id': f"chunk_{chunk_index}"
                        })
                        chunk_index += 1
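                        # Carry roughly overlap_tokens worth of trailing context into the
                        # next chunk; overlap_tokens // 10 approximates a sentence count,
                        # assuming ~10 tokens per sentence.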
                        overlap_text = ' '.join(current_chunk[-overlap_tokens//10:]) if overlap_tokens else ''
                        current_chunk = [overlap_text, sent] if overlap_text else [sent]
                        current_tokens = len(nlp(overlap_text)) + sent_tokens
                    else:
                        current_chunk.append(sent)
                        chunks.append({
                            'text': ' '.join(current_chunk),
                            'index': chunk_index,
                            'id': f"chunk_{chunk_index}"
                        })
                        chunk_index += 1
                        current_chunk = []
                        current_tokens = 0
                else:
                    current_chunk.append(sent)
                    current_tokens += sent_tokens

            if current_chunk:
                chunks.append({
                    'text': ' '.join(current_chunk),
                    'index': chunk_index,
                    'id': f"chunk_{chunk_index}"
                })
                chunk_index += 1

        self.stdout.write(f"Created {len(chunks)} sentence-based chunks")
        return chunks

    def fixed_chunking(self, text, chunk_size=500):
        """Fallback fixed-size chunking."""
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        return [{
            'text': chunk,
            'index': i,
            'id': f"chunk_{i}"
        } for i, chunk in enumerate(chunks)]

    def paragraph_chunking(self, text):
        """Chunk text into paragraphs."""
        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
        chunks = []
        chunk_index = 0

        self.stdout.write(f"Detected {len(paragraphs)} paragraphs")

        for para in paragraphs:
            chunks.append({
                'text': para,
                'index': chunk_index,
                'id': f"chunk_{chunk_index}"
            })
            chunk_index += 1

        self.stdout.write(f"Created {len(chunks)} paragraph-based chunks")
        return chunks

    def handle(self, *args, **options):
        start_time = time.time()
        chatbot_id = options['chatbot_id']
        chunk_strategy = options['chunk_strategy']

        try:
            chatbot = Chatbot.objects.get(id=chatbot_id)
            self.stdout.write(f"Starting RAG processing for chatbot: {chatbot.name} (ID: {chatbot_id})", ending='\n')
            self.stdout.flush()
        except Chatbot.DoesNotExist:
            self.stdout.write(self.style.ERROR(f"Chatbot with ID {chatbot_id} does not exist"), ending='\n')
            self.stdout.flush()
            return

        files = File.objects.filter(chatbot=chatbot)
        if not files:
            self.stdout.write(self.style.WARNING(f"No files found for chatbot {chatbot.name}"), ending='\n')
            self.stdout.flush()
            return

        # Persist the vector store under MEDIA_ROOT/chroma_db; os.path.join works
        # whether MEDIA_ROOT is configured as a str or a Path.
        chroma_client = chromadb.PersistentClient(path=os.path.join(settings.MEDIA_ROOT, 'chroma_db'))
        collection_name = f"chatbot_{chatbot_id}"
        try:
            chroma_client.delete_collection(name=collection_name)
            self.stdout.write(f"Cleared existing collection: {collection_name}", ending='\n')
            self.stdout.flush()
        except Exception:
            self.stdout.write(f"No existing collection to clear: {collection_name}", ending='\n')
            self.stdout.flush()
        try:
            collection = chroma_client.get_or_create_collection(name=collection_name)
            self.stdout.write(f"Initialized Chroma collection: {collection_name}", ending='\n')
            self.stdout.flush()
        except Exception as e:
            self.stdout.write(self.style.ERROR(f"Failed to initialize Chroma collection: {str(e)}"), ending='\n')
            self.stdout.flush()
            return

        self.stdout.write("Loading sentence-transformers model...", ending='\n')
        self.stdout.flush()
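        # all-MiniLM-L6-v2 is a small, fast sentence-transformers model that produces
        # 384-dimensional embeddings; it is downloaded from the Hugging Face hub on first use.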
        embedder = SentenceTransformer('all-MiniLM-L6-v2')
        self.stdout.write("Embedding model loaded successfully", ending='\n')
        self.stdout.flush()

        for file_obj in files:
            self.stdout.write(f"\nProcessing file: {file_obj.file.name}", ending='\n')
            self.stdout.flush()
            file_start_time = time.time()

            rag_metadata, created = RAGMetadata.objects.get_or_create(
                file=file_obj,
                defaults={'status': 'PROCESSING'}
            )
            if not created:
                # Reset status when reprocessing so a stale state is not left in place.
                rag_metadata.status = 'PROCESSING'
                rag_metadata.save()

            try:
                file_path = default_storage.path(file_obj.file.name)
                self.stdout.write(f"Reading file from: {file_path}", ending='\n')
                self.stdout.flush()
                text = self.extract_text_from_file(file_path)
                self.stdout.write(f"Total text extracted: {len(text)} characters", ending='\n')
                self.stdout.flush()

                self.stdout.write(f"Chunking text using {chunk_strategy} strategy...", ending='\n')
                self.stdout.flush()
                if chunk_strategy == 'sentence':
                    chunks = self.sentence_chunking(text, max_tokens=400, overlap_tokens=80)
                elif chunk_strategy == 'fixed':
                    chunks = self.fixed_chunking(text, chunk_size=500)
                elif chunk_strategy == 'paragraph':
                    chunks = self.paragraph_chunking(text)

                self.stdout.write(f"Created {len(chunks)} chunks", ending='\n')
                self.stdout.flush()

                self.stdout.write("Generating embeddings...", ending='\n')
                self.stdout.flush()
                chunk_texts = [chunk['text'] for chunk in chunks]
                embeddings = embedder.encode(chunk_texts, show_progress_bar=False).tolist()
                self.stdout.write(f"Generated {len(embeddings)} embeddings", ending='\n')
                self.stdout.flush()

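                # Note: the chunk ID string is stored as the Chroma "document", while the
                # actual chunk text lives in metadata['text']; the retrieval code is
                # assumed to read the text from metadata.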
                documents = [chunk['id'] for chunk in chunks]
                metadatas = [{
                    'chatbot_id': chatbot_id,
                    'file_id': file_obj.id,
                    'chunk_index': chunk['index'],
                    'text': chunk['text']
                } for chunk in chunks]
                collection.add(
                    documents=documents,
                    embeddings=embeddings,
                    metadatas=metadatas,
                    ids=[f"{file_obj.id}_{chunk['index']}" for chunk in chunks]
                )
                self.stdout.write(f"Stored {len(chunks)} chunks in Chroma collection", ending='\n')
                self.stdout.flush()

                rag_metadata.chunk_count = len(chunks)
                rag_metadata.status = 'COMPLETED'
                rag_metadata.save()
                self.stdout.write(f"Updated RAG metadata for file: {file_obj.file.name}", ending='\n')
                self.stdout.flush()

                file_duration = time.time() - file_start_time
                self.stdout.write(f"File processing completed in {file_duration:.2f} seconds", ending='\n')
                self.stdout.flush()

            except Exception as e:
                rag_metadata.status = 'FAILED'
                rag_metadata.error_message = str(e)
                rag_metadata.save()
                self.stdout.write(self.style.ERROR(f"Failed to process file {file_obj.file.name}: {str(e)}"), ending='\n')
                self.stdout.flush()
                logger.error(f"Failed to process file {file_obj.file.name}: {str(e)}")
                continue

        total_duration = time.time() - start_time
        self.stdout.write(self.style.SUCCESS(f"\nRAG processing completed for chatbot {chatbot.name} in {total_duration:.2f} seconds"), ending='\n')
        self.stdout.flush()