brendon-ai committed on
Commit ff570d4 · verified · 1 Parent(s): f4b92bb

Update app.py

Files changed (1)
  1. app.py +67 -63
app.py CHANGED
@@ -1,64 +1,68 @@
- public class TinyLlama extends OllamaContainer {
-
-     private final String imageName;
-
-     public TinyLlama(String imageName) {
-         super(DockerImageName.parse(imageName)
-                 .asCompatibleSubstituteFor("ollama/ollama:0.1.44"));
-         this.imageName = imageName;
-     }
-
-     public void createImage(String imageName) {
-         var ollama = new OllamaContainer("ollama/ollama:0.1.44");
-         ollama.start();
-         try {
-             ollama.execInContainer("apt-get", "update");
-             ollama.execInContainer("apt-get", "upgrade", "-y");
-             ollama.execInContainer("apt-get", "install", "-y", "python3-pip");
-             ollama.execInContainer("pip", "install", "huggingface-hub");
-             ollama.execInContainer(
-                     "huggingface-cli",
-                     "download",
-                     "DavidAU/DistiLabelOrca-TinyLLama-1.1B-Q8_0-GGUF",
-                     "distilabelorca-tinyllama-1.1b.Q8_0.gguf",
-                     "--local-dir",
-                     "."
-             );
-             ollama.execInContainer(
-                     "sh",
-                     "-c",
-                     String.format("echo '%s' > Modelfile", "FROM distilabelorca-tinyllama-1.1b.Q8_0.gguf")
-             );
-             ollama.execInContainer("ollama", "create", "distilabelorca-tinyllama-1.1b.Q8_0.gguf", "-f", "Modelfile");
-             ollama.execInContainer("rm", "distilabelorca-tinyllama-1.1b.Q8_0.gguf");
-             ollama.commitToImage(imageName);
-         } catch (IOException | InterruptedException e) {
-             throw new ContainerFetchException(e.getMessage());
-         }
-     }
-
-     public String getModelName() {
-         return "distilabelorca-tinyllama-1.1b.Q8_0.gguf";
-     }
-
-     @Override
-     public void start() {
-         try {
-             super.start();
-         } catch (ContainerFetchException ex) {
-             // If image doesn't exist, create it. Subsequent runs will reuse the image.
-             createImage(imageName);
-             super.start();
-         }
-     }
- }
-
- var tinyLlama = new TinyLlama("faq-ai");
- tinyLlama.start();
- String response = given()
-         .baseUri(tinyLlama.getEndpoint())
-         .header(new Header("Content-Type", "application/json"))
-         .body(new CompletionRequest(tinyLlama.getModelName() + ":latest", List.of(new Message("user", "What is the capital of France?")), false))
-         .post("/api/chat")
-         .getBody().as(ChatResponse.class).message.content;
- System.out.println("Response from LLM " + response);
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ import torch
+
+ app = FastAPI()
+
+ # Load model globally to avoid reloading on each request
+ tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
+ model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
+ model.eval()  # Set model to evaluation mode
+
+ class InferenceRequest(BaseModel):
+     text: str
+
+ class PredictionResult(BaseModel):
+     sequence: str
+     score: float
+     token: int
+     token_str: str
+
+ @app.post("/predict", response_model=list[PredictionResult])
+ async def predict_masked_lm(request: InferenceRequest):
+     text = request.text
+     inputs = tokenizer(text, return_tensors="pt")
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     logits = outputs.logits
+     masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
+
+     # Find all masked tokens
+     masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
+
+     results = []
+     for masked_index in masked_token_indices:
+         # Get the top 5 predictions for the masked token
+         # (single topk call; scores come from one softmax over this position's logits)
+         top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
+         probs = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
+
+         for i in range(5):
+             predicted_token_id = top_5_tokens[i].item()
+             score = probs[predicted_token_id].item()
+             predicted_token_str = tokenizer.decode(predicted_token_id)
+
+             # Replace the [MASK] with the predicted token for the full sequence,
+             # using a temporary copy of input_ids so the original inputs stay untouched
+             temp_input_ids = inputs["input_ids"].clone()
+             temp_input_ids[0, masked_index] = predicted_token_id
+             full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
+
+             results.append(PredictionResult(
+                 sequence=full_sequence,
+                 score=score,
+                 token=predicted_token_id,
+                 token_str=predicted_token_str
+             ))
+     return results
+
+ # Optional: a simple health check endpoint
+ @app.get("/")
+ async def root():
+     return {"message": "NeuroBERT-Tiny API is running!"}
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
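
For reference, a minimal client sketch for exercising the new endpoints. This snippet is an illustration, not part of the commit: it assumes the server above is running locally on port 8000, and the example sentence and the use of the requests library are assumptions.

import requests

BASE_URL = "http://localhost:8000"  # assumes app.py is being served locally, e.g. via uvicorn

# Health check against the root endpoint
print(requests.get(f"{BASE_URL}/").json())

# Request top-5 predictions for the [MASK] token; the sentence is illustrative
payload = {"text": "The capital of France is [MASK]."}
response = requests.post(f"{BASE_URL}/predict", json=payload)
response.raise_for_status()

# Each entry mirrors the PredictionResult model: sequence, score, token, token_str
for prediction in response.json():
    print(f"{prediction['token_str']} ({prediction['score']:.4f}): {prediction['sequence']}")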