Skip to content

Commit d2400f1

Browse files
authored
Feat: add full text mem (#544)
## Description <!-- Please include a summary of the changes below; Fill in the issue number that this PR addresses (if applicable); Fill in the related MemOS-Docs repository issue or PR link (if applicable); Mention the person who will review this PR (if you know who it is); Replace (summary), (issue), (docs-issue-or-pr-link), and (reviewer) with the appropriate information. 请在下方填写更改的摘要; 填写此 PR 解决的问题编号(如果适用); 填写相关的 MemOS-Docs 仓库 issue 或 PR 链接(如果适用); 提及将审查此 PR 的人(如果您知道是谁); 替换 (summary)、(issue)、(docs-issue-or-pr-link) 和 (reviewer) 为适当的信息。 --> Summary: (summary) Fix: #(issue) Docs Issue/PR: (docs-issue-or-pr-link) Reviewer: @(reviewer) ## Checklist: - [ ] I have performed a self-review of my own code | 我已自行检查了自己的代码 - [ ] I have commented my code in hard-to-understand areas | 我已在难以理解的地方对代码进行了注释 - [ ] I have added tests that prove my fix is effective or that my feature works | 我已添加测试以证明我的修复有效或功能正常 - [ ] I have created related documentation issue/PR in [MemOS-Docs](https://github.com/MemTensor/MemOS-Docs) (if applicable) | 我已在 [MemOS-Docs](https://github.com/MemTensor/MemOS-Docs) 中创建了相关的文档 issue/PR(如果适用) - [ ] I have linked the issue to this PR (if applicable) | 我已将 issue 链接到此 PR(如果适用) - [ ] I have mentioned the person who will review this PR | 我已提及将审查此 PR 的人
2 parents 5e651cd + 7e18a50 commit d2400f1

File tree

13 files changed

+327
-56
lines changed

13 files changed

+327
-56
lines changed

src/memos/api/handlers/component_init.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
4141
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
4242
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
43+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
4344

4445

4546
if TYPE_CHECKING:
@@ -142,7 +143,7 @@ def init_server() -> dict[str, Any]:
142143
)
143144

144145
logger.debug("Memory manager initialized")
145-
146+
tokenizer = FastTokenizer()
146147
# Initialize text memory
147148
text_mem = SimpleTreeTextMemory(
148149
llm=llm,
@@ -153,6 +154,7 @@ def init_server() -> dict[str, Any]:
153154
memory_manager=memory_manager,
154155
config=default_cube_config.text_mem.config,
155156
internet_retriever=internet_retriever,
157+
tokenizer=tokenizer,
156158
)
157159

158160
logger.debug("Text memory initialized")
@@ -270,7 +272,6 @@ def init_server() -> dict[str, Any]:
270272

271273
online_bot = get_online_bot_function() if dingding_enabled else None
272274
logger.info("DingDing bot is enabled")
273-
274275
# Return all components as a dictionary for easy access and extension
275276
return {
276277
"graph_db": graph_db,

src/memos/api/handlers/search_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _fast_search(
191191
"""
192192
target_session_id = search_req.session_id or "default_session"
193193
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
194-
194+
plugin = bool(search_req.source is not None and search_req.source == "plugin")
195195
search_results = self.naive_mem_cube.text_mem.search(
196196
query=search_req.query,
197197
user_name=user_context.mem_cube_id,
@@ -205,6 +205,7 @@ def _fast_search(
205205
"session_id": target_session_id,
206206
"chat_history": search_req.chat_history,
207207
},
208+
plugin=plugin,
208209
)
209210

210211
formatted_memories = [format_memory_item(data) for data in search_results]

src/memos/api/product_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class APISearchRequest(BaseRequest):
185185
)
186186
include_preference: bool = Field(True, description="Whether to handle preference memory")
187187
pref_top_k: int = Field(6, description="Number of preference results to return")
188+
source: str | None = Field(None, description="Source of the search")
188189

189190

190191
class APIADDRequest(BaseRequest):

src/memos/graph_dbs/polardb.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,115 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
14501450
"""Get the ordered context chain starting from a node."""
14511451
raise NotImplementedError
14521452

1453+
@timed
1454+
def search_by_fulltext(
1455+
self,
1456+
query_words: list[str],
1457+
top_k: int = 10,
1458+
scope: str | None = None,
1459+
status: str | None = None,
1460+
threshold: float | None = None,
1461+
search_filter: dict | None = None,
1462+
user_name: str | None = None,
1463+
tsvector_field: str = "properties_tsvector_zh",
1464+
tsquery_config: str = "jiebaqry",
1465+
**kwargs,
1466+
) -> list[dict]:
1467+
"""
1468+
Full-text search functionality using PostgreSQL's full-text search capabilities.
1469+
1470+
Args:
1471+
query_text: query text
1472+
top_k: maximum number of results to return
1473+
scope: memory type filter (memory_type)
1474+
status: status filter, defaults to "activated"
1475+
threshold: similarity threshold filter
1476+
search_filter: additional property filter conditions
1477+
user_name: username filter
1478+
tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1
1479+
tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation)
1480+
**kwargs: other parameters (e.g. cube_name)
1481+
1482+
Returns:
1483+
list[dict]: result list containing id and score
1484+
"""
1485+
# Build WHERE clause dynamically, same as search_by_embedding
1486+
where_clauses = []
1487+
1488+
if scope:
1489+
where_clauses.append(
1490+
f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype"
1491+
)
1492+
if status:
1493+
where_clauses.append(
1494+
f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype"
1495+
)
1496+
else:
1497+
where_clauses.append(
1498+
"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype"
1499+
)
1500+
1501+
# Add user_name filter
1502+
user_name = user_name if user_name else self.config.user_name
1503+
where_clauses.append(
1504+
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype"
1505+
)
1506+
1507+
# Add search_filter conditions
1508+
if search_filter:
1509+
for key, value in search_filter.items():
1510+
if isinstance(value, str):
1511+
where_clauses.append(
1512+
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype"
1513+
)
1514+
else:
1515+
where_clauses.append(
1516+
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype"
1517+
)
1518+
1519+
# Add fulltext search condition
1520+
# Convert query_text to OR query format: "word1 | word2 | word3"
1521+
tsquery_string = " | ".join(query_words)
1522+
1523+
where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")
1524+
1525+
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
1526+
1527+
# Build fulltext search query
1528+
query = f"""
1529+
SELECT
1530+
ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
1531+
agtype_object_field_text(properties, 'memory') as memory_text,
1532+
ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank
1533+
FROM "{self.db_name}_graph"."Memory"
1534+
{where_clause}
1535+
ORDER BY rank DESC
1536+
LIMIT {top_k};
1537+
"""
1538+
1539+
params = [tsquery_string, tsquery_string]
1540+
1541+
conn = self._get_connection()
1542+
try:
1543+
with conn.cursor() as cursor:
1544+
cursor.execute(query, params)
1545+
results = cursor.fetchall()
1546+
output = []
1547+
for row in results:
1548+
oldid = row[0] # old_id
1549+
rank = row[2] # rank score
1550+
1551+
id_val = str(oldid)
1552+
score_val = float(rank)
1553+
1554+
# Apply threshold filter if specified
1555+
if threshold is None or score_val >= threshold:
1556+
output.append({"id": id_val, "score": score_val})
1557+
1558+
return output[:top_k]
1559+
finally:
1560+
self._return_connection(conn)
1561+
14531562
@timed
14541563
def search_by_embedding(
14551564
self,

src/memos/memories/textual/item.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata):
199199
preference: str | None = Field(default=None, description="Preference.")
200200
created_at: str | None = Field(default=None, description="Timestamp of the dialog.")
201201
mem_cube_id: str | None = Field(default=None, description="ID of the MemCube.")
202+
score: float | None = Field(default=None, description="Score of the retrieval result.")
202203

203204

204205
class TextualMemoryItem(BaseModel):

src/memos/memories/textual/prefer_text_memory/extractor.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A
9090
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
9191
response = response.strip().replace("```json", "").replace("```", "").strip()
9292
result = json.loads(response)
93-
result["preference"] = result.pop("implicit_preference")
93+
for d in result:
94+
d["preference"] = d.pop("implicit_preference")
9495
return result
9596
except Exception as e:
9697
logger.error(f"Error extracting implicit preferences: {e}, return None")
@@ -136,20 +137,24 @@ def _process_single_chunk_implicit(
136137
if not implicit_pref:
137138
return None
138139

139-
vector_info = {
140-
"embedding": self.embedder.embed([implicit_pref["context_summary"]])[0],
141-
}
140+
memories = []
141+
for pref in implicit_pref:
142+
vector_info = {
143+
"embedding": self.embedder.embed([pref["context_summary"]])[0],
144+
}
142145

143-
extract_info = {**basic_info, **implicit_pref, **vector_info, **info}
146+
extract_info = {**basic_info, **pref, **vector_info, **info}
144147

145-
metadata = PreferenceTextualMemoryMetadata(
146-
type=msg_type, preference_type="implicit_preference", **extract_info
147-
)
148-
memory = TextualMemoryItem(
149-
id=extract_info["dialog_id"], memory=implicit_pref["context_summary"], metadata=metadata
150-
)
148+
metadata = PreferenceTextualMemoryMetadata(
149+
type=msg_type, preference_type="implicit_preference", **extract_info
150+
)
151+
memory = TextualMemoryItem(
152+
id=str(uuid.uuid4()), memory=pref["context_summary"], metadata=metadata
153+
)
151154

152-
return memory
155+
memories.append(memory)
156+
157+
return memories
153158

154159
def extract(
155160
self,

src/memos/memories/textual/prefer_text_memory/retrievers.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
from abc import ABC, abstractmethod
24
from typing import Any
35

@@ -34,9 +36,12 @@ def _naive_reranker(
3436
self, query: str, prefs_mem: list[TextualMemoryItem], top_k: int, **kwargs: Any
3537
) -> list[TextualMemoryItem]:
3638
if self.reranker:
37-
prefs_mem = self.reranker.rerank(query, prefs_mem, top_k)
38-
return [item for item, _ in prefs_mem]
39-
return prefs_mem
39+
prefs_mem_reranked = []
40+
prefs_mem_tuple = self.reranker.rerank(query, prefs_mem, top_k)
41+
for item, score in prefs_mem_tuple:
42+
item.metadata.score = score
43+
prefs_mem_reranked.append(item)
44+
return prefs_mem_reranked
4045

4146
def _original_text_reranker(
4247
self,
@@ -52,11 +57,22 @@ def _original_text_reranker(
5257
prefs_mem_for_reranker = deepcopy(prefs_mem)
5358
for pref_mem, pref in zip(prefs_mem_for_reranker, prefs, strict=False):
5459
pref_mem.memory = pref_mem.memory + "\n" + pref.original_text
55-
prefs_mem_for_reranker = self.reranker.rerank(query, prefs_mem_for_reranker, top_k)
56-
prefs_mem_for_reranker = [item for item, _ in prefs_mem_for_reranker]
60+
reranked_results = self.reranker.rerank(query, prefs_mem_for_reranker, top_k)
61+
prefs_mem_for_reranker = [item for item, _ in reranked_results]
5762
prefs_ids = [item.id for item in prefs_mem_for_reranker]
5863
prefs_dict = {item.id: item for item in prefs_mem}
59-
return [prefs_dict[item_id] for item_id in prefs_ids if item_id in prefs_dict]
64+
65+
# Create mapping from id to score from reranked results
66+
reranked_scores = {item.id: score for item, score in reranked_results}
67+
68+
# Assign scores to the original items
69+
result_items = []
70+
for item_id in prefs_ids:
71+
if item_id in prefs_dict:
72+
original_item = prefs_dict[item_id]
73+
original_item.metadata.score = reranked_scores.get(item_id)
74+
result_items.append(original_item)
75+
return result_items
6076
return prefs_mem
6177

6278
def retrieve(
@@ -119,9 +135,6 @@ def retrieve(
119135
if pref.payload.get("preference", None)
120136
]
121137

122-
# store explicit id and score, use it after reranker
123-
explicit_id_scores = {item.id: item.score for item in explicit_prefs}
124-
125138
reranker_map = {
126139
"naive": self._naive_reranker,
127140
"original_text": self._original_text_reranker,
@@ -136,7 +149,14 @@ def retrieve(
136149

137150
# filter explicit mem by score bigger than threshold
138151
explicit_prefs_mem = [
139-
item for item in explicit_prefs_mem if explicit_id_scores.get(item.id, 0) >= 0.0
152+
item
153+
for item in explicit_prefs_mem
154+
if item.metadata.score >= float(os.getenv("PREFERENCE_SEARCH_THRESHOLD", 0.0))
155+
]
156+
implicit_prefs_mem = [
157+
item
158+
for item in implicit_prefs_mem
159+
if item.metadata.score >= float(os.getenv("PREFERENCE_SEARCH_THRESHOLD", 0.0))
140160
]
141161

142162
return explicit_prefs_mem + implicit_prefs_mem

src/memos/memories/textual/simple_tree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from memos.memories.textual.tree import TreeTextMemory
1010
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
1111
from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
12+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
1213
from memos.reranker.base import BaseReranker
1314

1415

@@ -35,6 +36,7 @@ def __init__(
3536
config: TreeTextMemoryConfig,
3637
internet_retriever: None = None,
3738
is_reorganize: bool = False,
39+
tokenizer: FastTokenizer | None = None,
3840
):
3941
"""Initialize memory with the given configuration."""
4042
self.config: TreeTextMemoryConfig = config
@@ -51,6 +53,7 @@ def __init__(
5153
if self.search_strategy and self.search_strategy.get("bm25", False)
5254
else None
5355
)
56+
self.tokenizer = tokenizer
5457
self.reranker = reranker
5558
self.memory_manager: MemoryManager = memory_manager
5659
# Create internet retriever if configured

src/memos/memories/textual/tree.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(self, config: TreeTextMemoryConfig):
8989
)
9090
else:
9191
logger.info("No internet retriever configured")
92+
self.tokenizer = None
9293

9394
def add(
9495
self,
@@ -165,6 +166,7 @@ def search(
165166
moscube: bool = False,
166167
search_filter: dict | None = None,
167168
user_name: str | None = None,
169+
**kwargs,
168170
) -> list[TextualMemoryItem]:
169171
"""Search for memories based on a query.
170172
User query -> TaskGoalParser -> MemoryPathResolver ->
@@ -199,6 +201,7 @@ def search(
199201
moscube=moscube,
200202
search_strategy=self.search_strategy,
201203
manual_close_internet=manual_close_internet,
204+
tokenizer=self.tokenizer,
202205
)
203206
else:
204207
searcher = Searcher(
@@ -211,9 +214,17 @@ def search(
211214
moscube=moscube,
212215
search_strategy=self.search_strategy,
213216
manual_close_internet=manual_close_internet,
217+
tokenizer=self.tokenizer,
214218
)
215219
return searcher.search(
216-
query, top_k, info, mode, memory_type, search_filter, user_name=user_name
220+
query,
221+
top_k,
222+
info,
223+
mode,
224+
memory_type,
225+
search_filter,
226+
user_name=user_name,
227+
plugin=kwargs.get("plugin", False),
217228
)
218229

219230
def get_relevant_subgraph(

0 commit comments

Comments
 (0)