Spaces:
Runtime error
Runtime error
tech-envision
commited on
Commit
·
66bc790
1
Parent(s):
572053e
Add per-user VM registry and update upload paths
Browse files- src/chat.py +6 -5
- src/vm.py +40 -0
src/chat.py
CHANGED
@@ -27,7 +27,7 @@ from .db import (
|
|
27 |
from .log import get_logger
|
28 |
from .schema import Msg
|
29 |
from .tools import execute_terminal, set_vm
|
30 |
-
from .vm import
|
31 |
|
32 |
_LOG = get_logger(__name__)
|
33 |
|
@@ -48,18 +48,19 @@ class ChatSession:
|
|
48 |
self._conversation, _ = Conversation.get_or_create(
|
49 |
user=self._user, session_name=session
|
50 |
)
|
51 |
-
self._vm =
|
52 |
self._messages: List[Msg] = self._load_history()
|
53 |
self._ensure_system_prompt()
|
54 |
|
55 |
async def __aenter__(self) -> "ChatSession":
|
56 |
-
self._vm.
|
57 |
set_vm(self._vm)
|
58 |
return self
|
59 |
|
60 |
async def __aexit__(self, exc_type, exc, tb) -> None:
|
61 |
set_vm(None)
|
62 |
-
self._vm
|
|
|
63 |
if not _db.is_closed():
|
64 |
_db.close()
|
65 |
|
@@ -79,7 +80,7 @@ class ChatSession:
|
|
79 |
target = dest / src.name
|
80 |
shutil.copy(src, target)
|
81 |
add_document(self._user.username, str(target), src.name)
|
82 |
-
return f"/data/{
|
83 |
|
84 |
def _ensure_system_prompt(self) -> None:
|
85 |
if any(m.get("role") == "system" for m in self._messages):
|
|
|
27 |
from .log import get_logger
|
28 |
from .schema import Msg
|
29 |
from .tools import execute_terminal, set_vm
|
30 |
+
from .vm import VMRegistry
|
31 |
|
32 |
_LOG = get_logger(__name__)
|
33 |
|
|
|
48 |
self._conversation, _ = Conversation.get_or_create(
|
49 |
user=self._user, session_name=session
|
50 |
)
|
51 |
+
self._vm = None
|
52 |
self._messages: List[Msg] = self._load_history()
|
53 |
self._ensure_system_prompt()
|
54 |
|
55 |
async def __aenter__(self) -> "ChatSession":
|
56 |
+
self._vm = VMRegistry.acquire(self._user.username)
|
57 |
set_vm(self._vm)
|
58 |
return self
|
59 |
|
60 |
async def __aexit__(self, exc_type, exc, tb) -> None:
|
61 |
set_vm(None)
|
62 |
+
if self._vm:
|
63 |
+
VMRegistry.release(self._user.username)
|
64 |
if not _db.is_closed():
|
65 |
_db.close()
|
66 |
|
|
|
80 |
target = dest / src.name
|
81 |
shutil.copy(src, target)
|
82 |
add_document(self._user.username, str(target), src.name)
|
83 |
+
return f"/data/{src.name}"
|
84 |
|
85 |
def _ensure_system_prompt(self) -> None:
|
86 |
if any(m.get("role") == "system" for m in self._messages):
|
src/vm.py
CHANGED
@@ -4,6 +4,8 @@ import subprocess
|
|
4 |
import uuid
|
5 |
from pathlib import Path
|
6 |
|
|
|
|
|
7 |
from .config import UPLOAD_DIR
|
8 |
|
9 |
from .log import get_logger
|
@@ -108,3 +110,41 @@ class LinuxVM:
|
|
108 |
|
109 |
def __exit__(self, exc_type, exc, tb) -> None:
|
110 |
self.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import uuid
|
5 |
from pathlib import Path
|
6 |
|
7 |
+
from threading import Lock
|
8 |
+
|
9 |
from .config import UPLOAD_DIR
|
10 |
|
11 |
from .log import get_logger
|
|
|
110 |
|
111 |
def __exit__(self, exc_type, exc, tb) -> None:
|
112 |
self.stop()
|
113 |
+
|
114 |
+
|
115 |
+
class VMRegistry:
|
116 |
+
"""Manage Linux VM instances on a per-user basis."""
|
117 |
+
|
118 |
+
_vms: dict[str, LinuxVM] = {}
|
119 |
+
_counts: dict[str, int] = {}
|
120 |
+
_lock = Lock()
|
121 |
+
|
122 |
+
@classmethod
|
123 |
+
def acquire(cls, username: str) -> LinuxVM:
|
124 |
+
"""Return a running VM for ``username``, creating it if needed."""
|
125 |
+
|
126 |
+
with cls._lock:
|
127 |
+
vm = cls._vms.get(username)
|
128 |
+
if vm is None:
|
129 |
+
vm = LinuxVM(host_dir=str(Path(UPLOAD_DIR) / username))
|
130 |
+
cls._vms[username] = vm
|
131 |
+
cls._counts[username] = 0
|
132 |
+
cls._counts[username] += 1
|
133 |
+
|
134 |
+
vm.start()
|
135 |
+
return vm
|
136 |
+
|
137 |
+
@classmethod
|
138 |
+
def release(cls, username: str) -> None:
|
139 |
+
"""Release one reference to ``username``'s VM and stop it if unused."""
|
140 |
+
|
141 |
+
with cls._lock:
|
142 |
+
vm = cls._vms.get(username)
|
143 |
+
if vm is None:
|
144 |
+
return
|
145 |
+
|
146 |
+
cls._counts[username] -= 1
|
147 |
+
if cls._counts[username] <= 0:
|
148 |
+
vm.stop()
|
149 |
+
del cls._vms[username]
|
150 |
+
del cls._counts[username]
|