import os
import atexit
import torch

print("CUDA Available:", torch.cuda.is_available())
print("GPU Count:", torch.cuda.device_count())
print("Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")


SYSTEM_PROMPT = "You are a compliance assistant. Use the provided risk data to answer user questions. If a single risk object is given, provide a direct answer. If a list of risks is provided, summarize, compare, or analyze the collection as needed. Always base your response on the data provided."
hf_token = os.environ["HF_TOKEN"]  # fail fast at startup if the token is missing
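
# Note: hf_token is only read here; it is not passed to vLLM explicitly. Both
# huggingface_hub and vLLM pick up the HF_TOKEN environment variable when
# downloading gated weights. A minimal sketch of an explicit login instead
# (assumption, not part of the original flow):
#
#     from huggingface_hub import login
#     login(token=hf_token)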

class VllmApiServer:
    def __init__(
        self,
        model_path="casperhansen/llama-3.3-70b-instruct-awq",
        adapter_path="artemisiaai/fine-tuned-adapter",
        port=7860,  # Default HuggingFace Spaces port
        host="0.0.0.0"
    ):
        self.model_path = model_path
        self.adapter_path = adapter_path
        self.port = port
        self.host = host
        self.server_process = None
        
        # Register cleanup on exit
        atexit.register(self._cleanup_server)

    def _start_vllm_server(self):
        """Start vLLM OpenAI API server"""
        cmd = [
            "python", "-m", "vllm.entrypoints.openai.api_server",
            "--model", self.model_path,
            "--host", self.host,
            "--port", str(self.port),
            "--enable-lora",
            "--lora-modules", f"adapter={self.adapter_path}",
            "--max-lora-rank", "64",
            "--tensor-parallel-size", "4"
        ]
        
        print(f"Starting vLLM server with command: {' '.join(cmd)}")
        print(f"API will be available at: http://{self.host}:{self.port}/v1")
        
        # Replace the current process with the vLLM server (exec, not a subprocess),
        # so it stays the main process that HuggingFace Spaces monitors. On success
        # this call never returns.
        os.execvp("python", cmd)
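
    # Alternative launch path, sketched under the assumption that a managed child
    # process is preferred over exec (e.g. when running outside HuggingFace Spaces).
    # This is the variant that would actually populate self.server_process so that
    # _cleanup_server below has something to terminate. Not called anywhere by default.
    def _start_vllm_server_subprocess(self):
        """Sketch: launch the same vLLM command as a managed subprocess."""
        import subprocess

        cmd = [
            "python", "-m", "vllm.entrypoints.openai.api_server",
            "--model", self.model_path,
            "--host", self.host,
            "--port", str(self.port),
            "--enable-lora",
            "--lora-modules", f"adapter={self.adapter_path}",
            "--max-lora-rank", "64",
            "--tensor-parallel-size", "4",
        ]
        self.server_process = subprocess.Popen(cmd)
        return self.server_process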

    def _cleanup_server(self):
        """Terminate the vLLM server if it was started as a subprocess.

        With the exec-based launch above, self.server_process is never set, so this
        cleanup is effectively a no-op; it only matters for a subprocess-based launch.
        """
        if self.server_process:
            self.server_process.terminate()
            self.server_process.wait()

    def run(self):
        """Start the vLLM API server"""
        self._start_vllm_server()

if __name__ == "__main__":
    server = VllmApiServer()
    server.run()
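

# Example client, sketched for illustration: query the OpenAI-compatible endpoint
# from a separate process once the server is up. The model name "adapter" matches
# the name registered via --lora-modules above; the base model is served under its
# own model path. Assumes the `openai` Python package is installed; the function
# name and the sample risk payload are made up for this example.
def query_adapter_example():
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:7860/v1", api_key="EMPTY")
    response = client.chat.completions.create(
        model="adapter",  # LoRA adapter name passed to --lora-modules
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": "Summarize this risk: {'id': 1, 'severity': 'high'}"},
        ],
    )
    return response.choices[0].message.content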