Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 117 additions & 83 deletions veadk/memory/short_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import wraps
from typing import Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable, Literal

from google.adk.sessions import (
BaseSessionService,
Expand All @@ -34,6 +34,11 @@
)
from veadk.utils.logger import get_logger

if TYPE_CHECKING:
from google.adk.events import Event

from veadk import Agent

logger = get_logger(__name__)


Expand Down Expand Up @@ -69,57 +74,6 @@ class ShortTermMemory(BaseModel):
Default to `/tmp/veadk_local_database.db`.
after_load_memory_callback (Callable | None):
A callback to be called after loading memory from the backend. The callback function should accept `Session` as an input.

Examples:
### In-memory simple memory

You can initialize a short term memory with in-memory storage:

```python
from veadk import Agent, Runner
from veadk.memory.short_term_memory import ShortTermMemory
import asyncio

session_id = "veadk_playground_session"

agent = Agent()
short_term_memory = ShortTermMemory(backend="local")

runner = Runner(
agent=agent, short_term_memory=short_term_memory)

# This invocation will be stored in short-term memory
response = asyncio.run(runner.run(
messages="My name is VeADK", session_id=session_id
))
print(response)

# The history invocation can be fetched by model
response = asyncio.run(runner.run(
messages="Do you remember my name?", session_id=session_id # keep the same `session_id`
))
print(response)
```

### Memory with a Database URL

You can also use a database connection URL to initialize a short-term memory:

```python
from veadk.memory.short_term_memory import ShortTermMemory

short_term_memory = ShortTermMemory(db_url="...")
```

### Memory with SQLite

To start the short-term memory with a local SQLite database, set the backend to `sqlite`. A local database will be created at `local_database_path`:

```python
from veadk.memory.short_term_memory import ShortTermMemory

short_term_memory = ShortTermMemory(backend="sqlite", local_database_path="")
```
"""

backend: Literal["local", "mysql", "sqlite", "postgresql", "database"] = "local"
Expand Down Expand Up @@ -200,37 +154,6 @@ async def create_session(

Returns:
Session | None: The retrieved or newly created `Session` object, or `None` if the session creation failed.

Examples:
Create a new session manually:

```python
import asyncio

from veadk.memory import ShortTermMemory

app_name = "app_name"
user_id = "user_id"
session_id = "session_id"

short_term_memory = ShortTermMemory()

session = asyncio.run(
short_term_memory.create_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
)

print(session)

session = asyncio.run(
short_term_memory.session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
)

print(session)
```
"""
if isinstance(self._session_service, DatabaseSessionService):
list_sessions_response = await self._session_service.list_sessions(
Expand All @@ -254,3 +177,114 @@ async def create_session(
return await self._session_service.create_session(
app_name=app_name, user_id=user_id, session_id=session_id
)

async def generate_profile(
    self,
    app_name: str,
    user_id: str,
    session_id: str,
    events: list["Event"],
) -> list[str]:
    """Group events by topic with an LLM and persist each group as a JSON profile.

    A summarizer agent is asked to assign the given events to named groups.
    The result is written below
    `./profiles/memory/<app_name>/<user_id>/<session_id>/`:

    - `profile_list.json`: JSON list of all group names.
    - `<group_name>.json`: one object per group with the keys `name`,
      `event_ids` and `event_list` (the JSON dump of each matched event's
      content).

    Group names come from the model, so they are normalized to `[a-z_]`
    before being used as file names. Groups that end up with the same name
    are merged instead of overwriting each other.

    Args:
        app_name (str): Application name, used as a path component.
        user_id (str): User id, used as a path component.
        session_id (str): Session id, used as a path component.
        events (list[Event]): The events to summarize.

    Returns:
        list[str]: The (normalized) names of the generated groups.

    Raises:
        json.JSONDecodeError: If the model response is not valid JSON.
    """
    import json
    import re

    from veadk import Agent, Runner
    from veadk.memory.types import MemoryProfile
    from veadk.utils.misc import write_string_to_file

    event_text = "".join(
        f"- Event id: {event.id}\nEvent content: {event.content}\n"
        for event in events
    )

    agent = Agent(
        name="memory_summarizer",
        description="A summarizer that summarizes the memory events.",
        instruction="""Summarize the memory events into different groups according to the event content. An event can belong to multiple groups. You must output the summary in JSON format (Each group should have a simple name (only a-z and _ is allowed), and a list of event ids):
[
{
"name": "",
"event_ids": ["Event id here"]
},
{
"name": "",
"event_ids": ["Event id here"]
}
]""",
        model_name="deepseek-v3-2-251201",
        output_schema=MemoryProfile,
    )
    runner = Runner(agent=agent)

    response = await runner.run(messages="Events are: \n" + event_text)

    groups = json.loads(response)
    # `output_schema` describes a single group object while the prompt asks
    # for a list, so accept both shapes.
    if isinstance(groups, dict):
        groups = [groups]

    # The name is interpolated into a file path below. Never trust the model
    # to respect the "a-z and _ only" rule: normalize it so it cannot contain
    # path separators or `..`.
    event_ids_by_group: dict[str, list[str]] = {}
    for group in groups:
        name = re.sub(r"[^a-z_]", "_", str(group.get("name", "")).lower())
        if not name:
            name = "unnamed"
        elif name == "profile_list":
            # Reserved: would clash with the index file `profile_list.json`.
            name = "profile_list_"
        merged_ids = event_ids_by_group.setdefault(name, [])
        for event_id in group.get("event_ids", []):
            if event_id not in merged_ids:
                merged_ids.append(event_id)

    # One lookup table instead of rescanning `events` for every event id.
    # Events without content have nothing to dump and are left out.
    content_by_id = {
        event.id: event.content for event in events if event.content is not None
    }

    # profile path: ./profiles/memory/<app_name>/<user_id>/<session_id>/<group_name>.json
    profile_dir = f"./profiles/memory/{app_name}/{user_id}/{session_id}"
    group_names = list(event_ids_by_group)

    write_string_to_file(
        content=json.dumps(group_names, ensure_ascii=False),
        file_path=f"{profile_dir}/profile_list.json",
    )

    for name, event_ids in event_ids_by_group.items():
        profile = {
            "name": name,
            "event_ids": event_ids,
            # Ids the model invented (no matching event) are silently skipped.
            "event_list": [
                content_by_id[event_id].model_dump_json()
                for event_id in event_ids
                if event_id in content_by_id
            ],
        }
        write_string_to_file(
            content=json.dumps(profile, ensure_ascii=False),
            file_path=f"{profile_dir}/{name}.json",
        )
    return group_names

async def compact_history_events(
    self,
    app_name: str,
    user_id: str,
    session_id: str,
    compact_limit: int,
    agent: "Agent",
) -> None:
    """Move the oldest events of a session into on-disk profiles.

    Steps:
        1. Select the leading events of the session that cover the first
           `compact_limit` user messages (every event before the
           `compact_limit + 1`-th user message).
        2. Summarize them into named groups via `generate_profile`.
        3. Drop them from `session.events`.
        4. Tell the agent about the groups (appended to `agent.instruction`)
           and register the `load_history_events` tool so it can read them
           back on demand.

    Nothing happens (no LLM call, no agent mutation) when the session cannot
    be found or there are no events to compact.

    Args:
        app_name (str): Application name of the session.
        user_id (str): User id of the session.
        session_id (str): Id of the session to compact.
        compact_limit (int): Number of user messages whose events are compacted.
        agent (Agent): The agent whose instruction and tools are extended.
    """
    session = await self.session_service.get_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )
    if session is None:
        logger.warning(
            f"Session {session_id} not found, skip compacting history events."
        )
        return

    compact_event_num = 0
    user_message_count = 0
    for event in session.events:
        # Events may carry no content; those are never user messages.
        if event.content is not None and event.content.role == "user":
            user_message_count += 1
            if user_message_count > compact_limit:
                break
        compact_event_num += 1

    if compact_event_num == 0:
        logger.debug("No history events to compact.")
        return

    events_need_compact = session.events[:compact_event_num]  # type: ignore

    group_names = await self.generate_profile(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        events=events_need_compact,
    )

    # TODO(yaozheng): directly editing the events does not work as expected,
    # need to check the reason later
    session.events = session.events[compact_event_num:]  # type: ignore
    logger.debug(f"Compacted {compact_event_num} events.")

    # Report the number of events actually compacted; `compact_limit` counts
    # user messages, not events.
    agent.instruction += f"""
The session has been compacted for the first {compact_event_num} events. The compacted content are divided into following groups:

{group_names}

You can call `load_history_events` to load the compacted events if you need them according to the user's request.
"""

    from veadk.tools.load_history_events import load_history_events

    # Repeated compaction must not register the same tool twice.
    if load_history_events not in agent.tools:
        agent.tools.append(load_history_events)
20 changes: 20 additions & 0 deletions veadk/memory/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pydantic import BaseModel


class MemoryProfile(BaseModel):
    """A named group of session events.

    Passed as `output_schema` to the summarizer agent in
    `ShortTermMemory.generate_profile`; it has the same `name` / `event_ids`
    shape as each object requested in that agent's prompt.
    """

    # Group name; the summarizer prompt asks for a-z and `_` only.
    name: str
    # Ids of the events assigned to this group.
    event_ids: list[str]
45 changes: 45 additions & 0 deletions veadk/tools/load_history_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from pathlib import Path

from google.adk.tools.tool_context import ToolContext


def load_profile(profile_path: Path) -> dict:
    """Read the file at `profile_path` and return its parsed JSON content.

    Args:
        profile_path (Path): Location of the JSON profile file.

    Returns:
        dict: The decoded JSON document.
    """
    # NOTE(review): the file is opened with the platform default encoding;
    # confirm this matches the encoding used when the profile was written.
    with open(profile_path, "r") as profile_file:
        return json.load(profile_file)


def load_history_events(group_names: list[str], tool_context: ToolContext) -> dict:
    """Load necessary history events by group names.

    Args:
        group_names (list[str]): The list of group names to load events for.

    Returns:
        dict: A mapping from each requested group name to its list of events. A group that does not exist maps to an empty list.
    """
    # App, user and session of the running invocation select the directory
    # that holds this session's profiles.
    invocation_context = tool_context._invocation_context
    app_name = invocation_context.app_name
    user_id = invocation_context.user_id
    session_id = invocation_context.session.id

    profile_dir = Path(
        f"./profiles/memory/{app_name}/{user_id}/{session_id}"
    ).resolve()

    events = {}
    for group_name in group_names:
        profile_path = (profile_dir / f"{group_name}.json").resolve()

        # `group_name` is chosen by the model: refuse anything that resolves
        # outside this session's profile directory (e.g. "../other/secret").
        if profile_dir not in profile_path.parents:
            events[group_name] = []
            continue

        try:
            profile = load_profile(profile_path)
        except (OSError, json.JSONDecodeError):
            # An unknown or unreadable group must not abort the tool call;
            # report it as empty so the model can carry on.
            events[group_name] = []
            continue

        events[group_name] = profile.get("event_list", [])
    return events