tech-envision committed on
Commit
27b075e
·
1 Parent(s): ccb1848

Add session-based conversation history

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. run.py +1 -1
  3. src/chat.py +44 -11
  4. src/db.py +4 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  # llm-backend
2
 
3
- This project provides a simple async interface to interact with an Ollama model and demonstrates basic tool usage. Chat histories are stored in a local SQLite database using Peewee.
4
 
5
  ## Usage
6
 
 
1
  # llm-backend
2
 
3
+ This project provides a simple async interface to interact with an Ollama model and demonstrates basic tool usage. Chat histories are stored in a local SQLite database using Peewee. Histories are persisted per user and session so conversations can be resumed with context.
4
 
5
  ## Usage
6
 
run.py CHANGED
@@ -6,7 +6,7 @@ from src.chat import ChatSession
6
 
7
 
8
  async def _main() -> None:
9
- async with ChatSession(user="demo_user") as chat:
10
  answer = await chat.chat("What did you just say?")
11
  print("\n>>>", answer)
12
 
 
6
 
7
 
8
  async def _main() -> None:
9
+ async with ChatSession(user="demo_user", session="demo_session") as chat:
10
  answer = await chat.chat("What did you just say?")
11
  print("\n>>>", answer)
12
 
src/chat.py CHANGED
@@ -6,7 +6,7 @@ import json
6
  from ollama import AsyncClient, ChatResponse, Message
7
 
8
  from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
9
- from .db import Conversation, Message, User, _db, init_db
10
  from .log import get_logger
11
  from .schema import Msg
12
  from .tools import add_two_numbers
@@ -15,11 +15,21 @@ _LOG = get_logger(__name__)
15
 
16
 
17
  class ChatSession:
18
- def __init__(self, user: str = "default", host: str = OLLAMA_HOST, model: str = MODEL_NAME) -> None:
 
 
 
 
 
 
19
  init_db()
20
  self._client = AsyncClient(host=host)
21
  self._model = model
22
  self._user, _ = User.get_or_create(username=user)
 
 
 
 
23
 
24
  async def __aenter__(self) -> "ChatSession":
25
  return self
@@ -28,6 +38,27 @@ class ChatSession:
28
  if not _db.is_closed():
29
  _db.close()
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  @staticmethod
32
  def _store_assistant_message(
33
  conversation: Conversation, message: Message
@@ -39,7 +70,7 @@ class ChatSession:
39
  else:
40
  content = message.content or ""
41
 
42
- Message.create(conversation=conversation, role="assistant", content=content)
43
 
44
  async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
45
  return await self._client.chat(
@@ -69,7 +100,7 @@ class ChatSession:
69
  "content": str(result),
70
  }
71
  )
72
- Message.create(
73
  conversation=conversation,
74
  role="tool",
75
  content=str(result),
@@ -83,14 +114,16 @@ class ChatSession:
83
  return response
84
 
85
  async def chat(self, prompt: str) -> str:
86
- conversation = Conversation.create(user=self._user)
87
- Message.create(conversation=conversation, role="user", content=prompt)
88
- messages: List[Msg] = [{"role": "user", "content": prompt}]
89
- response = await self.ask(messages)
90
- messages.append(response.message.model_dump())
91
- self._store_assistant_message(conversation, response.message)
92
 
93
  _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")
94
 
95
- final_resp = await self._handle_tool_calls(messages, response, conversation)
 
 
96
  return final_resp.message.content
 
6
  from ollama import AsyncClient, ChatResponse, Message
7
 
8
  from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
9
+ from .db import Conversation, Message as DBMessage, User, _db, init_db
10
  from .log import get_logger
11
  from .schema import Msg
12
  from .tools import add_two_numbers
 
15
 
16
 
17
  class ChatSession:
18
+ def __init__(
19
+ self,
20
+ user: str = "default",
21
+ session: str = "default",
22
+ host: str = OLLAMA_HOST,
23
+ model: str = MODEL_NAME,
24
+ ) -> None:
25
  init_db()
26
  self._client = AsyncClient(host=host)
27
  self._model = model
28
  self._user, _ = User.get_or_create(username=user)
29
+ self._conversation, _ = Conversation.get_or_create(
30
+ user=self._user, session_name=session
31
+ )
32
+ self._messages: List[Msg] = self._load_history()
33
 
34
  async def __aenter__(self) -> "ChatSession":
35
  return self
 
38
  if not _db.is_closed():
39
  _db.close()
40
 
41
+ def _load_history(self) -> List[Msg]:
42
+ messages: List[Msg] = []
43
+ for msg in self._conversation.messages.order_by(DBMessage.created_at):
44
+ if msg.role == "assistant":
45
+ try:
46
+ calls = json.loads(msg.content)
47
+ except json.JSONDecodeError:
48
+ messages.append({"role": "assistant", "content": msg.content})
49
+ else:
50
+ messages.append(
51
+ {
52
+ "role": "assistant",
53
+ "tool_calls": [Message.ToolCall(**c) for c in calls],
54
+ }
55
+ )
56
+ elif msg.role == "user":
57
+ messages.append({"role": "user", "content": msg.content})
58
+ else:
59
+ messages.append({"role": "tool", "content": msg.content})
60
+ return messages
61
+
62
  @staticmethod
63
  def _store_assistant_message(
64
  conversation: Conversation, message: Message
 
70
  else:
71
  content = message.content or ""
72
 
73
+ DBMessage.create(conversation=conversation, role="assistant", content=content)
74
 
75
  async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
76
  return await self._client.chat(
 
100
  "content": str(result),
101
  }
102
  )
103
+ DBMessage.create(
104
  conversation=conversation,
105
  role="tool",
106
  content=str(result),
 
114
  return response
115
 
116
  async def chat(self, prompt: str) -> str:
117
+ DBMessage.create(conversation=self._conversation, role="user", content=prompt)
118
+ self._messages.append({"role": "user", "content": prompt})
119
+
120
+ response = await self.ask(self._messages)
121
+ self._messages.append(response.message.model_dump())
122
+ self._store_assistant_message(self._conversation, response.message)
123
 
124
  _LOG.info("Thinking:\n%s", response.message.thinking or "<no thinking trace>")
125
 
126
+ final_resp = await self._handle_tool_calls(
127
+ self._messages, response, self._conversation
128
+ )
129
  return final_resp.message.content
src/db.py CHANGED
@@ -31,8 +31,12 @@ class User(BaseModel):
31
  class Conversation(BaseModel):
32
  id = AutoField()
33
  user = ForeignKeyField(User, backref="conversations")
 
34
  started_at = DateTimeField(default=datetime.utcnow)
35
 
 
 
 
36
 
37
  class Message(BaseModel):
38
  id = AutoField()
 
31
  class Conversation(BaseModel):
32
  id = AutoField()
33
  user = ForeignKeyField(User, backref="conversations")
34
+ session_name = CharField()
35
  started_at = DateTimeField(default=datetime.utcnow)
36
 
37
+ class Meta:
38
+ indexes = ((("user", "session_name"), True),)
39
+
40
 
41
  class Message(BaseModel):
42
  id = AutoField()