# implementations/async_memory.py
from datetime import datetime
from typing import Dict, List, Optional
from zoneinfo import ZoneInfo

from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.future import select
from sqlalchemy.orm import sessionmaker

from app.settings import DatabaseSettings, MemorySettings
from app.memory.memory import ConversationMemoryInterface
from app.utils.token_counter import SimpleTokenCounter, TikTokenCounter
from app.memory.models.base import Base
from app.memory.models.message import Message
from app.memory.models.user import User

class AsyncPostgresConversationMemory(ConversationMemoryInterface):
    """Conversation memory persisted through an async SQLAlchemy engine.

    Messages are stored as ``Message`` rows linked to an existing ``User``.
    After every insert the stored messages are trimmed oldest-first so that
    their total token count (as measured by the configured token counter)
    does not exceed ``memory_settings.token_limit``.
    """

    def __init__(self, db_settings: DatabaseSettings, memory_settings: MemorySettings):
        """Create the async engine, the session factory and the token counter.

        Args:
            db_settings: Supplies the connection URL and the pool settings
                (``pool_size``, ``max_overflow``, ``pool_timeout``).
            memory_settings: Supplies ``token_limit``, the ``token_counter``
                choice (``"tiktoken"`` selects ``TikTokenCounter`` built with
                ``model_name``; any other value selects ``SimpleTokenCounter``).
        """
        self.engine = create_async_engine(
            db_settings.url,
            pool_size=db_settings.pool_size,
            max_overflow=db_settings.max_overflow,
            pool_timeout=db_settings.pool_timeout,
        )

        # expire_on_commit=False keeps already-loaded ORM attributes readable
        # after a commit without an extra refresh round-trip.
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )
        self.token_limit = memory_settings.token_limit

        if memory_settings.token_counter == "tiktoken":
            self.token_counter = TikTokenCounter(memory_settings.model_name)
        else:
            self.token_counter = SimpleTokenCounter()

    async def initialize(self):
        """Initialize the database by creating all tables registered on ``Base``."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def add_message(
        self,
        username: str,
        role: str,
        message: str,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Store a message for an existing user, then trim memory if needed.

        Args:
            username: Username of the message owner; the user must already exist.
            role: Role label stored alongside the message.
            message: The message text.
            timestamp: Creation time of the message. Defaults to "now" in the
                ``Asia/Jakarta`` time zone.

        Raises:
            ValueError: If no user with ``username`` exists.
        """
        async with self.async_session() as session:
            # Resolve the owning user (User is imported at module level).
            result = await session.execute(select(User).filter_by(username=username))
            user = result.scalars().first()
            if user is None:
                raise ValueError(f"User with username '{username}' not found")

            if timestamp is None:
                # NOTE(review): the default zone is hard-coded; history is ordered
                # by this column, so confirm callers never pass naive or
                # differently-zoned timestamps.
                timestamp = datetime.now(ZoneInfo("Asia/Jakarta"))

            msg = Message(user_id=user.id, role=role, message=message, timestamp=timestamp)
            session.add(msg)
            await session.commit()
            await self.trim_memory_if_needed(session)

    async def get_all_history(self) -> List[Dict]:
        """Return every stored message, oldest first, across all users.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts.
            NOTE(review): ``get_history`` uses the key ``"parts"`` instead of
            ``"content"``; confirm which shape each caller expects.
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(Message).order_by(Message.timestamp)
            )
            messages = result.scalars().all()
            return [{"role": msg.role, "content": msg.message} for msg in messages]

    async def get_history(
        self,
        username: Optional[str] = None,
        token_limit: Optional[int] = None,
        last_n: Optional[int] = None,
    ) -> List[Dict]:
        """Return the most recent messages that fit within the given limits.

        Messages are scanned newest-first and accumulated until adding the next
        one would exceed ``token_limit`` or ``last_n`` messages have been
        collected; the selection is then returned in chronological order.

        Args:
            username: If given, only messages owned by this user are considered.
            token_limit: Maximum total tokens (per ``self.token_counter``) of the
                returned messages. The newest message is always included, even
                if it alone exceeds this limit.
            last_n: Maximum number of messages to return. ``0`` or a negative
                value returns an empty list.

        Returns:
            A list of ``{"role": ..., "parts": ...}`` dicts, oldest first.
        """
        # Without this guard the first message would be appended before the
        # last_n check ever ran, so last_n=0 returned one message.
        if last_n is not None and last_n <= 0:
            return []

        async with self.async_session() as session:
            query = select(Message).order_by(Message.timestamp)
            if username is not None:
                # Join with User table and filter by username.
                query = query.join(User).filter(User.username == username)
            result = await session.execute(query)
            messages = result.scalars().all()

        selected = []
        total_tokens = 0
        for msg in reversed(messages):
            tokens = self.token_counter.count_tokens(msg.message)
            # The newest message is always kept (even if it alone exceeds the
            # limit); every later candidate must fit within the remaining budget.
            if token_limit is not None and selected and total_tokens + tokens > token_limit:
                break
            selected.append(msg)
            total_tokens += tokens
            if last_n is not None and len(selected) >= last_n:
                break

        # Restore chronological order.
        selected.reverse()
        return [{"role": msg.role, "parts": msg.message} for msg in selected]

    async def clear_memory(self) -> None:
        """Delete every stored message (for all users)."""
        async with self.async_session() as session:
            # A Select has no .delete(); use the dedicated DELETE construct.
            await session.execute(delete(Message))
            await session.commit()

    async def get_total_tokens(self) -> int:
        """Return the total token count of all stored messages (all users)."""
        async with self.async_session() as session:
            # Only the text column is needed, so avoid loading full rows.
            result = await session.execute(select(Message.message))
            texts = result.scalars().all()
            return sum(self.token_counter.count_tokens(text) for text in texts)

    async def trim_memory_if_needed(self, session: AsyncSession) -> None:
        """Delete the oldest messages until the total fits ``self.token_limit``.

        NOTE(review): the budget is shared across ALL users — one user's new
        message can evict another user's history. Confirm this is intended.

        Args:
            session: An open session; it is committed before returning.
        """
        result = await session.execute(select(Message).order_by(Message.timestamp))
        messages = result.scalars().all()
        # Count each message once and walk oldest-first, instead of
        # re-counting and using list.pop(0) (O(n) per removal).
        token_counts = [self.token_counter.count_tokens(msg.message) for msg in messages]
        total_tokens = sum(token_counts)

        for oldest, tokens in zip(messages, token_counts):
            if total_tokens <= self.token_limit:
                break
            await session.delete(oldest)
            total_tokens -= tokens

        await session.commit()