#!/usr/bin/env python3

# AI imports
import cmd
import qdrant_client
import openai
from fastembed.rerank.cross_encoder import TextCrossEncoder
# Regular imports
import uuid
import os
import hashlib
import argparse

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from fastapi import FastAPI
import uvicorn

# Module-wide configuration.
# Each model identifier can be overridden through the environment variable
# of the same name; the literal is the fallback used when it is unset.
EMBED_MODEL = os.environ.get("EMBED_MODEL", "jinaai/jina-embeddings-v2-small-en")  # passed to set_model()
SPARSE_MODEL = os.environ.get("SPARSE_MODEL", "prithivida/Splade_PP_en_v1")  # passed to set_sparse_model()
RANK_MODEL = os.environ.get("RANK_MODEL", "Xenova/ms-marco-MiniLM-L-6-v2")  # passed to TextCrossEncoder
# Name of the Qdrant collection that holds the indexed documents.
COLLECTION_NAME = "rag"
# Needed on macOS to avoid tokenizer errors.
os.environ["TOKENIZERS_PARALLELISM"] = "true"


def eprint(e, exit_code):
    """Print an error message to stderr and terminate the process.

    Args:
        e: Exception (or any object) whose string form is reported. Leading
            and trailing single/double quotes are stripped from the text.
        exit_code: Process exit status handed to ``sys.exit``.

    Raises:
        SystemExit: Always, carrying ``exit_code``.
    """
    # Imported locally because the module does not import ``sys`` at the top.
    import sys

    print("Error: " + str(e).strip("'\""), file=sys.stderr)
    sys.exit(exit_code)

# Helper Classes and Functions

class QueryRequest(BaseModel):
    """Request body accepted by the ``/v1/chat/completions`` endpoint."""

    # Model name supplied by the caller; the endpoint echoes it back in its
    # response.
    model: str
    # Conversation messages; the endpoint rejects the request with HTTP 400
    # when this list is empty.
    # NOTE(review): element shape is not validated here - presumably
    # OpenAI-style {"role": ..., "content": ...} dicts; confirm with clients.
    messages: list
    # NOTE(review): no code visible in this file reads this field.
    max_tokens: int = 150  # Default to 150 tokens

class Rag(cmd.Cmd):
    """Retrieval-augmented chat, usable as an interactive shell or HTTP service.

    Each question is answered by fetching candidate documents from a Qdrant
    collection, reranking them with a cross-encoder, and sending the best
    ones together with the recent chat history to an OpenAI-compatible
    chat-completions endpoint.
    """

    prompt = "> "
    # Maximum number of messages kept in ``chat_history`` (5 user + 5 AI).
    MAX_HISTORY = 10

    def __init__(self, vector_path):
        """Open the vector store and set up the models and the LLM client.

        Args:
            vector_path: Path handed to ``QdrantClient(path=...)``.
        """
        # Initialize the cmd class
        super().__init__()

        self.client = qdrant_client.QdrantClient(path=vector_path)
        self.client.set_model(EMBED_MODEL)
        self.client.set_sparse_model(SPARSE_MODEL)
        self.reranker = TextCrossEncoder(model_name=RANK_MODEL)

        # Setup openai api
        self.llm = openai.OpenAI(api_key="your-api-key", base_url="http://localhost:8080")
        self.chat_history = []  # Store chat history

    def do_EOF(self, user_content):
        """Leave the interactive loop on end-of-file (Ctrl-D)."""
        print("")
        return True

    def _trim_history(self):
        """Cap ``chat_history`` at ``MAX_HISTORY`` messages.

        Whenever something has to be dropped, trimming continues up to the
        next user message so the history never starts with an answer whose
        question has already been discarded.
        """
        history = self.chat_history
        overflow = len(history) - self.MAX_HISTORY
        if overflow <= 0:
            return
        drop = overflow
        while drop < len(history) and history[drop]["role"] != "user":
            drop += 1
        del history[:drop]

    def _retrieve_context(self, prompt):
        """Return the five best-ranked documents for ``prompt`` as one string.

        Up to 20 candidates are fetched from the collection and scored by the
        reranker; the five highest-scoring ones are joined with single spaces.
        An empty string is returned when nothing was retrieved.
        """
        results = self.client.query(
            collection_name=COLLECTION_NAME,
            query_text=prompt,
            limit=20,
        )
        documents = [r.document for r in results]
        if not documents:
            # Nothing to rank; skip the reranker call entirely.
            return ""

        scores = list(self.reranker.rerank(prompt, documents))
        # Indices of the documents ordered by descending reranker score.
        ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        return " ".join(str(documents[i]) for i in ranked[:5])

    def query(self, prompt):
        """Answer ``prompt`` using retrieved context and the chat history.

        The answer is streamed to stdout while it is generated, and both the
        question and the answer are recorded in ``chat_history``.

        Args:
            prompt: The user's question as a string.

        Returns:
            The complete answer text produced by the LLM.
        """
        # Add user query to chat history
        self.chat_history.append({"role": "user", "content": prompt})
        self._trim_history()

        # Query the Qdrant client for relevant context
        reranked_context = self._retrieve_context(prompt)

        # Prepare the metaprompt with chat history and context
        metaprompt = f"""
            You are an expert software architect.  
            Use the provided context and chat history to answer the question accurately and concisely.  
            If the answer is not explicitly stated, infer the most reasonable answer based on the available information.  
            If there is no relevant information, respond with "I don't know"—do not fabricate details.  

            ### Chat History:
            {self.format_chat_history()}

            ### Context:  
            {reranked_context.strip()}  

            ### Question:  
            {prompt.strip()}  

            ### Answer:
            """

        # Query the LLM with the metaprompt
        response = self.llm.chat.completions.create(
            model="your-model-name",
            messages=[{"role": "user", "content": metaprompt}],
            stream=True,
        )

        # Collect the AI response while echoing it to the terminal
        pieces = []
        for chunk in response:
            if not chunk.choices:
                # Skip chunks that carry no choices instead of failing on [0].
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                pieces.append(delta)
                print(delta, end="", flush=True)
        full_response = "".join(pieces)

        # Add AI response to chat history
        self.chat_history.append({"role": "assistant", "content": full_response})
        self._trim_history()

        print(" ")
        return full_response

    def format_chat_history(self):
        """Format the chat history into a string for inclusion in the metaprompt.

        Every message becomes one ``User: ...`` or ``AI: ...`` line, chosen
        from the message's own role rather than its position. When the last
        message is a user message, an empty ``AI: `` line is appended as the
        slot for the pending answer.
        """
        labels = {"user": "User", "assistant": "AI"}
        lines = [
            f"{labels.get(message['role'], message['role'])}: {message['content']}"
            for message in self.chat_history
        ]
        if self.chat_history and self.chat_history[-1]["role"] == "user":
            lines.append("AI: ")
        return "\n".join(lines)

    def default(self, user_content):
        """Treat any unrecognised input line as a question (``/bye`` quits)."""
        if user_content == "/bye":
            return True

        # The answer is deliberately not returned: cmd.Cmd stops its loop
        # when a handler returns a true value.
        self.query(user_content)

    @staticmethod
    def _last_user_prompt(messages):
        """Return the text of the most recent user message, or ``None``.

        Only dict entries whose role is ``"user"`` and whose content is a
        non-blank string are considered.
        """
        for message in reversed(messages):
            if not isinstance(message, dict) or message.get("role") != "user":
                continue
            content = message.get("content")
            if isinstance(content, str) and content.strip():
                return content
        return None

    def serve(self):
        """Expose the instance over HTTP with an OpenAI-like completions route.

        Blocks inside ``uvicorn.run`` until the server stops.
        """
        # Imported locally because the module does not import ``time``.
        import time

        app = FastAPI()

        @app.post("/v1/chat/completions")
        async def query(request: QueryRequest):
            """API endpoint to query the Rag instance and return OpenAI-like response."""

            if not request.messages:
                raise HTTPException(status_code=400, detail="Messages are required")

            # Rag.query() works on a single prompt string, so use the latest
            # user message of the submitted conversation.
            prompt = self._last_user_prompt(request.messages)
            if prompt is None:
                raise HTTPException(
                    status_code=400,
                    detail="A user message with text content is required",
                )

            try:
                response_text = self.query(prompt)
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e)) from e

            return JSONResponse(content={
                "id": str(uuid.uuid4()),
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": response_text
                        },
                        "finish_reason": "stop",
                        "index": 0
                    }
                ]
            })

        # NOTE(review): this is the same port as the LLM base_url configured
        # in __init__ (localhost:8080) - confirm the two cannot collide.
        uvicorn.run(app, host="0.0.0.0", port=8080)

def run_rag(vector_path):
    """Start an interactive RAG session backed by the store at ``vector_path``.

    Ctrl-C ends the session quietly: a newline is printed and the function
    returns instead of propagating ``KeyboardInterrupt``.
    """
    session = Rag(vector_path)
    try:
        session.cmdloop()
    except KeyboardInterrupt:
        # Move the shell prompt onto a fresh line after ^C.
        print("")

def serve_rag(vector_path):
    """Build a ``Rag`` for the store at ``vector_path`` and run its HTTP service."""
    Rag(vector_path).serve()

def load():
    """Instantiate every model once against a throw-away in-memory store.

    Nothing is returned or kept; presumably the point is to have the model
    files fetched and cached ahead of the first real run - TODO confirm.
    """
    scratch = qdrant_client.QdrantClient(":memory:")
    scratch.set_model(EMBED_MODEL)
    scratch.set_sparse_model(SPARSE_MODEL)
    # Constructed only for its side effects; the instance is discarded.
    TextCrossEncoder(model_name=RANK_MODEL)


# Command-line interface: one sub-command per mode of operation. Each
# sub-parser stores its handler in ``func`` via set_defaults().
parser = argparse.ArgumentParser(description="A script to enable Rag")
subparsers = parser.add_subparsers(dest='command')

run_parser = subparsers.add_parser('run', help='Run RAG interactively')
run_parser.add_argument("vector_path", type=str, help="Path to the vector database")
run_parser.set_defaults(func=run_rag)

serve_parser = subparsers.add_parser('serve', help='Run RAG as a service')
serve_parser.add_argument("vector_path", type=str, help="Path to the vector database")
serve_parser.set_defaults(func=serve_rag)

load_parser = subparsers.add_parser('load', help='Preload RAG Embedding Models')
load_parser.set_defaults(func=load)

try:
    args = parser.parse_args()

    if args.command in ('run', 'serve'):
        # These handlers take the vector database path.
        args.func(args.vector_path)
    elif args.command:
        # 'load' takes no arguments.
        args.func()
    else:
        # No sub-command given: show usage instead of exiting silently.
        parser.print_help()
except ValueError as e:
    eprint(e, 1)
