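"""Multi-model query fusion: run a prompt across several local LLMs concurrently
and merge the highest-scoring responses into a single answer."""
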
import asyncio
import logging
import torch
from difflib import SequenceMatcher
from time import time
from typing import Dict, List, Tuple

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from config_settings_public import MODEL_PATHS, MODEL_CONFIG
from model_manager_public import ModelManager
from response_optimizer_public import ResponseOptimizer

logger = logging.getLogger(__name__)


class QueryProcessor:
    """Fans a query out to several local models and fuses their responses."""

    def __init__(self, model_manager: ModelManager, response_optimizer: ResponseOptimizer):
        self.model_manager = model_manager
        self.response_optimizer = response_optimizer
        self.model_paths = MODEL_PATHS
        self.model_config = MODEL_CONFIG
        self.memory_bank = []  # running log of fused query/response pairs

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
            torch.backends.cudnn.benchmark = True  # autotune cuDNN kernels for fixed-size inputs
            torch.backends.cudnn.enabled = True
            torch.set_float32_matmul_precision('high')  # allow TF32 matmuls on supported GPUs
            logger.info(f"✅ QueryProcessor initialized on {self.device}")
        else:
            logger.warning("⚠️ Running on CPU")

        # Single "balanced" tier for now; select_models() filters it against MODEL_PATHS.
        self.model_tiers = {
            "balanced": ["gemma", "mistral", "llama"]
        }

    def select_models(self, query: str) -> List[str]:
        # Keep only tier members that actually have a configured model path.
        return [m for m in self.model_tiers["balanced"] if m in self.model_paths]

    async def process_query_with_fusion(self, query: str) -> str:
        """Run the query on every selected model concurrently and fuse the results."""
        selected_models = self.select_models(query)
        logger.info(f"🧠 Selected models: {selected_models}")

        # return_exceptions=True keeps one model's failure from cancelling the
        # rest; exceptions and empty outputs are filtered out just below.
        tasks = [asyncio.create_task(self.invoke_model(m, query)) for m in selected_models]
        responses = await asyncio.gather(*tasks, return_exceptions=True)

        valid_responses = {
            m: r for m, r in zip(selected_models, responses)
            if isinstance(r, str) and r.strip()
        }

        if not valid_responses:
            logger.error("❌ All models failed or returned empty responses.")
            return "⚠️ No valid responses received from models."

        logger.info(f"✅ {len(valid_responses)} valid responses received.")
        fused = await self.advanced_fusion(valid_responses, query)
        self.memory_bank.append({"query": query, "response": fused, "timestamp": time()})
        return fused

    async def invoke_model(self, model_name: str, query: str) -> str:
        """Invoke a single model, loading it from disk on first use. Returns "" on failure."""
        try:
            model = self.model_manager.loaded_models.get(model_name)
            if not model:
                model = self.model_manager.load_model_from_disk(model_name)
                if not model:
                    raise RuntimeError(f"Model '{model_name}' failed to load.")

            config = self.model_config.get(model_name, {})
            max_tokens = config.get("max_tokens", 800)
            temperature = config.get("temperature", 0.7)

            logger.info(f"🔧 Invoking '{model_name}' with temp={temperature}, max_tokens={max_tokens}")

            # Run the blocking completion off the event loop. (A torch.no_grad()
            # wrapper here would be ineffective: grad mode is thread-local, so it
            # would not propagate into the to_thread worker anyway.)
            completion = await asyncio.to_thread(
                model.create_completion,
                prompt=query,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=["</s>", "User:", "Assistant:"]
            )

            response_text = completion.get("choices", [{}])[0].get("text", "").strip()
            return response_text

        except Exception as e:
            logger.error(f"❌ Error in model '{model_name}': {e}")
            return ""

    def clean_response(self, raw) -> str:
        """Normalize the shapes different backends return (str, dict, list) to plain text."""
        if isinstance(raw, str):
            return raw.strip()

        if isinstance(raw, dict):
            return raw.get("generated_text", "").strip() or raw.get("text", "").strip()

        if isinstance(raw, list):
            # Guard against an empty list before indexing raw[0].
            if raw and isinstance(raw[0], dict) and "generated_text" in raw[0]:
                return raw[0]["generated_text"].strip()
            return " ".join(str(x).strip() for x in raw if isinstance(x, str))

        return str(raw).strip()

    async def advanced_fusion(self, responses: Dict[str, str], query: str) -> str:
        """Deduplicate near-identical responses, rank the rest by relevance, and merge the top three."""
        cleaned = {m: self.clean_response(r) for m, r in responses.items()}
        seen = set()
        unique: List[Tuple[str, float, str]] = []

        for model, text in cleaned.items():
            # Skip empty responses and ones that near-duplicate an earlier model's output.
            if not text or any(self.is_similar(text, prev) for prev in seen):
                continue
            seen.add(text)
            score = self.score_response(text, query)
            unique.append((text, score, model))

        if not unique:
            return "⚠️ No high-quality content to merge."

        unique.sort(key=lambda x: x[1], reverse=True)
        best_responses = [await self.response_optimizer.optimize_response(t, query) for t, _, _ in unique[:3]]
        return "\n---\n".join(best_responses).strip()

    def is_similar(self, a: str, b: str, threshold: float = 0.9) -> bool:
        # Case-insensitive fuzzy match; a ratio above the threshold counts as a duplicate.
        return SequenceMatcher(None, a.strip().lower(), b.strip().lower()).ratio() > threshold

    def score_response(self, response: str, query: str) -> float:
        # Relevance = TF-IDF cosine similarity between response and query; fall
        # back to a neutral 0.5 if vectorization fails (e.g. on empty text).
        try:
            tfidf = TfidfVectorizer().fit_transform([response, query])
            return float(cosine_similarity(tfidf[0:1], tfidf[1:2])[0][0])
        except Exception:
            return 0.5
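

# Usage sketch: a minimal async entry point showing how the processor is wired
# together. Hedged assumptions: ModelManager and ResponseOptimizer are
# constructed here with no arguments — check their real signatures in
# model_manager_public and response_optimizer_public before relying on this.
async def _demo() -> None:
    processor = QueryProcessor(ModelManager(), ResponseOptimizer())
    answer = await processor.process_query_with_fusion("Explain TF-IDF in one paragraph.")
    print(answer)


if __name__ == "__main__":
    asyncio.run(_demo())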