1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
| import json import sqlite3 from datetime import datetime from typing import Optional
class MemoryStore:
    """SQLite-backed memory store.

    Each memory is a row in the ``memories`` table holding a session id,
    a key, a value, an integer importance and a creation timestamp filled
    in by SQLite's ``CURRENT_TIMESTAMP`` default.

    The store owns a single connection, exposed as ``self.conn``. Call
    :meth:`close` (or use the store as a context manager) to release it.
    """

    def __init__(self, db_path: str = "agent_memory.db"):
        """Open (or create) the database and make sure the schema exists.

        Args:
            db_path: Path handed to ``sqlite3.connect``.
        """
        self.conn = sqlite3.connect(db_path)
        self.conn.execute(
            """
            CREATE TABLE IF NOT EXISTS memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT,
                key TEXT,
                value TEXT,
                importance INTEGER DEFAULT 1,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )
        self.conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key)
            """
        )
        self.conn.commit()

    def __enter__(self) -> "MemoryStore":
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the connection when the ``with`` block ends."""
        self.close()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()

    def save(self, session_id: str, key: str, value: str, importance: int = 1):
        """Insert one memory row and commit immediately.

        Args:
            session_id: Identifier of the session the memory belongs to.
            key: Key later used by :meth:`recall` and :meth:`forget`.
            value: The remembered content.
            importance: Ranking weight; higher values sort first.
        """
        self.conn.execute(
            "INSERT INTO memories (session_id, key, value, importance) VALUES (?, ?, ?, ?)",
            (session_id, key, value, importance),
        )
        self.conn.commit()

    def recall(self, key: str, limit: int = 5) -> list[dict]:
        """Return memories whose key contains ``key``.

        Matching uses SQL ``LIKE`` with ``key`` wrapped in ``%``, so any
        ``%`` or ``_`` inside ``key`` acts as a wildcard.

        Args:
            key: Substring to look for in the stored keys.
            limit: Maximum number of rows, passed to SQL ``LIMIT``.

        Returns:
            A list of dicts with the keys ``key``, ``value``, ``importance``
            and ``created_at``, ordered by importance (highest first) and
            then newest first.
        """
        cursor = self.conn.execute(
            "SELECT key, value, importance, created_at FROM memories "
            "WHERE key LIKE ? "
            # created_at only has one-second resolution, so id breaks ties
            # between rows inserted within the same second.
            "ORDER BY importance DESC, created_at DESC, id DESC LIMIT ?",
            (f"%{key}%", limit),
        )
        # Rows come back as plain sequences (no row_factory is configured),
        # so pair each one with the column names to build the dicts.
        columns = [column[0] for column in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def summarize_session(self, session_id: str) -> str:
        """Render every memory of one session as a bulleted text block.

        Args:
            session_id: Session whose memories are listed.

        Returns:
            One ``- key: value`` line per memory, most important first
            (insertion order among equal importance), or ``"无记忆"`` when
            the session has no memories.
        """
        cursor = self.conn.execute(
            "SELECT key, value FROM memories WHERE session_id = ? "
            "ORDER BY importance DESC, id ASC",
            (session_id,),
        )
        memories = cursor.fetchall()
        if not memories:
            return "无记忆"
        return "\n".join(f"- {k}: {v}" for k, v in memories)

    def forget(self, key: str):
        """Delete every memory whose key equals ``key`` exactly, then commit.

        Args:
            key: Exact key to delete (no wildcard matching, unlike
                :meth:`recall`).
        """
        self.conn.execute("DELETE FROM memories WHERE key = ?", (key,))
        self.conn.commit()
class WorkingMemory:
    """Working memory: the message context of the current session.

    Messages are stored as ``{"role": ..., "content": ...}`` dicts in
    ``self.context``. When the combined content length grows past
    ``max_tokens``, older non-system messages are dropped and replaced by a
    single system message stating how many were dropped.

    Note that the budget is compared against the number of *characters*
    (``len(content)``), not against a real token count.
    """

    def __init__(self, max_tokens: int = 8000, keep_recent: int = 10):
        """Create an empty working memory.

        Args:
            max_tokens: Size budget; pruning starts once the summed content
                length exceeds it.
            keep_recent: Number of most recent non-system messages that
                survive a prune.

        Raises:
            ValueError: If ``keep_recent`` is negative.
        """
        if keep_recent < 0:
            raise ValueError("keep_recent must be non-negative")
        self.context: list[dict] = []
        self.max_tokens = max_tokens
        self.keep_recent = keep_recent
        # Running total of dropped messages, reported by the summary marker.
        self._compressed_count = 0
        # The summary marker currently sitting in self.context, if any.
        self._marker = None

    def add(self, role: str, content: str):
        """Append a message and prune the context if it is over budget.

        Args:
            role: Message role, e.g. ``"system"`` or ``"user"``.
            content: Message text.
        """
        self.context.append({"role": role, "content": content})
        self._prune()

    def _prune(self):
        """Drop old messages once the context exceeds the size budget.

        All system messages are kept (once each), followed by one summary
        marker and the ``keep_recent`` newest non-system messages. The
        marker carries the cumulative number of dropped messages. When
        there is nothing to drop the context is left untouched, even if it
        is still over budget.
        """
        total = sum(len(m["content"]) for m in self.context)
        if total <= self.max_tokens:
            return

        marker = self._marker
        if marker is not None and not any(m is marker for m in self.context):
            # The context was replaced or edited from outside; the old
            # running count no longer describes it, so start afresh.
            marker = self._marker = None
            self._compressed_count = 0

        # The marker has role "system" too; exclude it by identity so that
        # repeated prunes replace it instead of piling markers up.
        system_msgs = [
            m for m in self.context if m["role"] == "system" and m is not marker
        ]
        dialogue = [m for m in self.context if m["role"] != "system"]

        # Computing the cut index (rather than slicing with a negative
        # index) keeps keep_recent == 0 meaning "keep none".
        cut = max(len(dialogue) - self.keep_recent, 0)
        if cut == 0:
            return

        self._compressed_count += cut
        self._marker = {
            "role": "system",
            "content": f"[已压缩 {self._compressed_count} 条历史消息]",
        }
        self.context = system_msgs + [self._marker] + dialogue[cut:]

    def get_context(self) -> list[dict]:
        """Return the live context list (not a copy)."""
        return self.context
|