tech-envision commited on
Commit
bf45c7d
·
1 Parent(s): 4dbdfda

Add user-based session tracking

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. run.py +1 -1
  3. src/chat.py +4 -3
  4. src/db.py +8 -2
README.md CHANGED
@@ -8,4 +8,4 @@ This project provides a simple async interface to interact with an Ollama model
8
  python run.py
9
  ```
10
 
11
- The script will ask the model to compute an arithmetic expression and print the answer. Conversations are automatically persisted to `chat.db`.
 
8
  python run.py
9
  ```
10
 
11
+ The script will ask the model to compute an arithmetic expression and print the answer. Conversations are automatically persisted to `chat.db` and are now associated with a user and session.
run.py CHANGED
@@ -6,7 +6,7 @@ from src.chat import ChatSession
6
 
7
 
8
  async def _main() -> None:
9
- async with ChatSession() as chat:
10
  answer = await chat.chat("What is 10 + 23?")
11
  print("\n>>>", answer)
12
 
 
6
 
7
 
8
  async def _main() -> None:
9
+ async with ChatSession(user="demo_user") as chat:
10
  answer = await chat.chat("What is 10 + 23?")
11
  print("\n>>>", answer)
12
 
src/chat.py CHANGED
@@ -6,7 +6,7 @@ import json
6
  from ollama import AsyncClient, ChatResponse
7
 
8
  from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
9
- from .db import Conversation, Message, _db, init_db
10
  from .log import get_logger
11
  from .schema import Msg
12
  from .tools import add_two_numbers
@@ -15,10 +15,11 @@ _LOG = get_logger(__name__)
15
 
16
 
17
  class ChatSession:
18
- def __init__(self, host: str = OLLAMA_HOST, model: str = MODEL_NAME) -> None:
19
  init_db()
20
  self._client = AsyncClient(host=host)
21
  self._model = model
 
22
 
23
  async def __aenter__(self) -> "ChatSession":
24
  return self
@@ -82,7 +83,7 @@ class ChatSession:
82
  return response
83
 
84
  async def chat(self, prompt: str) -> str:
85
- conversation = Conversation.create()
86
  Message.create(conversation=conversation, role="user", content=prompt)
87
  messages: List[Msg] = [{"role": "user", "content": prompt}]
88
  response = await self.ask(messages)
 
6
  from ollama import AsyncClient, ChatResponse
7
 
8
  from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
9
+ from .db import Conversation, Message, User, _db, init_db
10
  from .log import get_logger
11
  from .schema import Msg
12
  from .tools import add_two_numbers
 
15
 
16
 
17
  class ChatSession:
18
+ def __init__(self, user: str = "default", host: str = OLLAMA_HOST, model: str = MODEL_NAME) -> None:
19
  init_db()
20
  self._client = AsyncClient(host=host)
21
  self._model = model
22
+ self._user, _ = User.get_or_create(username=user)
23
 
24
  async def __aenter__(self) -> "ChatSession":
25
  return self
 
83
  return response
84
 
85
  async def chat(self, prompt: str) -> str:
86
+ conversation = Conversation.create(user=self._user)
87
  Message.create(conversation=conversation, role="user", content=prompt)
88
  messages: List[Msg] = [{"role": "user", "content": prompt}]
89
  response = await self.ask(messages)
src/db.py CHANGED
@@ -23,8 +23,14 @@ class BaseModel(Model):
23
  database = _db
24
 
25
 
 
 
 
 
 
26
  class Conversation(BaseModel):
27
  id = AutoField()
 
28
  started_at = DateTimeField(default=datetime.utcnow)
29
 
30
 
@@ -36,11 +42,11 @@ class Message(BaseModel):
36
  created_at = DateTimeField(default=datetime.utcnow)
37
 
38
 
39
- __all__ = ["_db", "Conversation", "Message"]
40
 
41
 
42
  def init_db() -> None:
43
  """Initialise the database and create tables if they do not exist."""
44
  if _db.is_closed():
45
  _db.connect()
46
- _db.create_tables([Conversation, Message])
 
23
  database = _db
24
 
25
 
26
+ class User(BaseModel):
27
+ id = AutoField()
28
+ username = CharField(unique=True)
29
+
30
+
31
  class Conversation(BaseModel):
32
  id = AutoField()
33
+ user = ForeignKeyField(User, backref="conversations")
34
  started_at = DateTimeField(default=datetime.utcnow)
35
 
36
 
 
42
  created_at = DateTimeField(default=datetime.utcnow)
43
 
44
 
45
+ __all__ = ["_db", "User", "Conversation", "Message"]
46
 
47
 
48
  def init_db() -> None:
49
  """Initialise the database and create tables if they do not exist."""
50
  if _db.is_closed():
51
  _db.connect()
52
+ _db.create_tables([User, Conversation, Message])