from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
import torch
import torch.nn.functional as F
import logging
import time

# ------------------- Logging Setup --------------------
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO
)

# ------------------- FastAPI Setup --------------------
app = FastAPI()

# ------------------- Model Config --------------------
MODEL_ID = "Qwen/Qwen3-Embedding-0.6B"
USE_GPU = torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side='left')
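# Note: left padding follows the Qwen3-Embedding usage examples; with the
# mask-aware mean pooling below, the padding side does not affect the result.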

if USE_GPU:
    logging.info("🔋 Loading model on GPU with 4-bit quantization...")
    # 4-bit loading requires the bitsandbytes package; recent transformers
    # versions expect a BitsAndBytesConfig instead of the deprecated
    # load_in_4bit kwarg, with the compute dtype set inside the config.
    model = AutoModel.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        ),
    )
else:
    logging.info("🧠 Loading model on CPU...")
    model = AutoModel.from_pretrained(
        MODEL_ID,
        device_map="cpu",
        torch_dtype=torch.float32
    )

model.eval()
device = next(model.parameters()).device

# ------------------- Data Schemas --------------------
class EmbedRequest(BaseModel):
    texts: List[str]

class EmbedResponse(BaseModel):
    embeddings: List[List[float]]
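
# Illustrative request/response payloads (values below are made up):
#   POST /embed  {"texts": ["hello world", "qwen embeddings"]}
#   -> {"embeddings": [[0.012, -0.034, ...], [0.051, 0.002, ...]]}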

# ------------------- Pooling Function --------------------
def masked_mean_pooling(last_hidden_state, attention_mask):
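    """Mask-aware mean pooling over token embeddings.

    last_hidden_state: (batch, seq_len, hidden) token outputs from the model.
    attention_mask:    (batch, seq_len), 1 for real tokens, 0 for padding.
    Returns a (batch, hidden) tensor: the mean over real tokens only, so
    padded positions contribute nothing to the average.
    """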
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    masked_embeddings = last_hidden_state * mask
    summed = masked_embeddings.sum(dim=1)
    counts = mask.sum(dim=1)
    return summed / counts.clamp(min=1e-9)

# ------------------- API Endpoint --------------------
@app.post("/embed", response_model=EmbedResponse)
def embed_texts(request: EmbedRequest):
    # Declared sync (not async) so FastAPI runs this blocking model call in
    # its worker thread pool instead of stalling the event loop.
    overall_start = time.perf_counter()
    logging.info(f"📩 Received request with {len(request.texts)} texts.")

    # Tokenization
    t0 = time.perf_counter()
    inputs = tokenizer(
        request.texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=32768  # Qwen3-Embedding supports sequences up to 32K tokens
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    t1 = time.perf_counter()
    logging.info(f"🧾 Tokenization took {t1 - t0:.3f} seconds.")

    # Model inference
    t2 = time.perf_counter()
    with torch.no_grad():
        outputs = model(**inputs)
    t3 = time.perf_counter()
    logging.info(f"🧠 Model inference took {t3 - t2:.3f} seconds.")

    # Pooling
    t4 = time.perf_counter()
    pooled = masked_mean_pooling(outputs.last_hidden_state, inputs['attention_mask'])
    normalized = F.normalize(pooled, p=2, dim=1)
    t5 = time.perf_counter()
    logging.info(f"🌀 Pooling & normalization took {t5 - t4:.3f} seconds.")

    # Total
    overall_end = time.perf_counter()
    logging.info(f"✅ Total processing time: {overall_end - overall_start:.3f} seconds.")

    return {"embeddings": normalized.cpu().tolist()}