Spaces:
Runtime error
Runtime error
tech-envision
commited on
Commit
·
bf45c7d
1
Parent(s):
4dbdfda
Add user-based session tracking
Browse files
README.md
CHANGED
@@ -8,4 +8,4 @@ This project provides a simple async interface to interact with an Ollama model
|
|
8 |
python run.py
|
9 |
```
|
10 |
|
11 |
-
The script will ask the model to compute an arithmetic expression and print the answer. Conversations are automatically persisted to `chat.db
|
|
|
8 |
python run.py
|
9 |
```
|
10 |
|
11 |
+
The script will ask the model to compute an arithmetic expression and print the answer. Conversations are automatically persisted to `chat.db` and are now associated with a user and session.
|
run.py
CHANGED
@@ -6,7 +6,7 @@ from src.chat import ChatSession
|
|
6 |
|
7 |
|
8 |
async def _main() -> None:
|
9 |
-
async with ChatSession() as chat:
|
10 |
answer = await chat.chat("What is 10 + 23?")
|
11 |
print("\n>>>", answer)
|
12 |
|
|
|
6 |
|
7 |
|
8 |
async def _main() -> None:
|
9 |
+
async with ChatSession(user="demo_user") as chat:
|
10 |
answer = await chat.chat("What is 10 + 23?")
|
11 |
print("\n>>>", answer)
|
12 |
|
src/chat.py
CHANGED
@@ -6,7 +6,7 @@ import json
|
|
6 |
from ollama import AsyncClient, ChatResponse
|
7 |
|
8 |
from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
|
9 |
-
from .db import Conversation, Message, _db, init_db
|
10 |
from .log import get_logger
|
11 |
from .schema import Msg
|
12 |
from .tools import add_two_numbers
|
@@ -15,10 +15,11 @@ _LOG = get_logger(__name__)
|
|
15 |
|
16 |
|
17 |
class ChatSession:
|
18 |
-
def __init__(self, host: str = OLLAMA_HOST, model: str = MODEL_NAME) -> None:
|
19 |
init_db()
|
20 |
self._client = AsyncClient(host=host)
|
21 |
self._model = model
|
|
|
22 |
|
23 |
async def __aenter__(self) -> "ChatSession":
|
24 |
return self
|
@@ -82,7 +83,7 @@ class ChatSession:
|
|
82 |
return response
|
83 |
|
84 |
async def chat(self, prompt: str) -> str:
|
85 |
-
conversation = Conversation.create()
|
86 |
Message.create(conversation=conversation, role="user", content=prompt)
|
87 |
messages: List[Msg] = [{"role": "user", "content": prompt}]
|
88 |
response = await self.ask(messages)
|
|
|
6 |
from ollama import AsyncClient, ChatResponse
|
7 |
|
8 |
from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
|
9 |
+
from .db import Conversation, Message, User, _db, init_db
|
10 |
from .log import get_logger
|
11 |
from .schema import Msg
|
12 |
from .tools import add_two_numbers
|
|
|
15 |
|
16 |
|
17 |
class ChatSession:
|
18 |
+
def __init__(self, user: str = "default", host: str = OLLAMA_HOST, model: str = MODEL_NAME) -> None:
|
19 |
init_db()
|
20 |
self._client = AsyncClient(host=host)
|
21 |
self._model = model
|
22 |
+
self._user, _ = User.get_or_create(username=user)
|
23 |
|
24 |
async def __aenter__(self) -> "ChatSession":
|
25 |
return self
|
|
|
83 |
return response
|
84 |
|
85 |
async def chat(self, prompt: str) -> str:
|
86 |
+
conversation = Conversation.create(user=self._user)
|
87 |
Message.create(conversation=conversation, role="user", content=prompt)
|
88 |
messages: List[Msg] = [{"role": "user", "content": prompt}]
|
89 |
response = await self.ask(messages)
|
src/db.py
CHANGED
@@ -23,8 +23,14 @@ class BaseModel(Model):
|
|
23 |
database = _db
|
24 |
|
25 |
|
|
|
|
|
|
|
|
|
|
|
26 |
class Conversation(BaseModel):
|
27 |
id = AutoField()
|
|
|
28 |
started_at = DateTimeField(default=datetime.utcnow)
|
29 |
|
30 |
|
@@ -36,11 +42,11 @@ class Message(BaseModel):
|
|
36 |
created_at = DateTimeField(default=datetime.utcnow)
|
37 |
|
38 |
|
39 |
-
__all__ = ["_db", "Conversation", "Message"]
|
40 |
|
41 |
|
42 |
def init_db() -> None:
|
43 |
"""Initialise the database and create tables if they do not exist."""
|
44 |
if _db.is_closed():
|
45 |
_db.connect()
|
46 |
-
_db.create_tables([Conversation, Message])
|
|
|
23 |
database = _db
|
24 |
|
25 |
|
26 |
+
class User(BaseModel):
|
27 |
+
id = AutoField()
|
28 |
+
username = CharField(unique=True)
|
29 |
+
|
30 |
+
|
31 |
class Conversation(BaseModel):
|
32 |
id = AutoField()
|
33 |
+
user = ForeignKeyField(User, backref="conversations")
|
34 |
started_at = DateTimeField(default=datetime.utcnow)
|
35 |
|
36 |
|
|
|
42 |
created_at = DateTimeField(default=datetime.utcnow)
|
43 |
|
44 |
|
45 |
+
__all__ = ["_db", "User", "Conversation", "Message"]
|
46 |
|
47 |
|
48 |
def init_db() -> None:
|
49 |
"""Initialise the database and create tables if they do not exist."""
|
50 |
if _db.is_closed():
|
51 |
_db.connect()
|
52 |
+
_db.create_tables([User, Conversation, Message])
|