mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-19 11:45:10 +08:00
Feat:memory sdk (#12538)
### What problem does this PR solve? Move memory and message apis to /api, and add sdk support. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -0,0 +1,52 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pytest
|
||||
import random
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def add_memory_func(client, request):
|
||||
def cleanup():
|
||||
memory_list_res = client.list_memory()
|
||||
exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]]
|
||||
for memory_id in exist_memory_ids:
|
||||
client.delete_memory(memory_id)
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
|
||||
memory_ids = []
|
||||
for i in range(3):
|
||||
payload = {
|
||||
"name": f"test_memory_{i}",
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = client.create_memory(**payload)
|
||||
memory_ids.append(res.id)
|
||||
request.cls.memory_ids = memory_ids
|
||||
return memory_ids
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def delete_test_memory(client, request):
|
||||
def cleanup():
|
||||
memory_list_res = client.list_memory()
|
||||
exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]]
|
||||
for memory_id in exist_memory_ids:
|
||||
client.delete_memory(memory_id)
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
return
|
||||
@ -0,0 +1,108 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
from ragflow_sdk import RAGFlow
|
||||
from hypothesis import example, given, settings
|
||||
from utils.hypothesis_utils import valid_names
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
ids=["empty_auth", "invalid_api_token"]
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.create_memory(**{"name": "test_memory", "memory_type": ["raw"], "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", "llm_id": "glm-4-flash@ZHIPU-AI"})
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("delete_test_memory")
|
||||
class TestMemoryCreate:
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("e" * 128)
|
||||
@settings(max_examples=20)
|
||||
def test_name(self, client, name):
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
memory = client.create_memory(**payload)
|
||||
pattern = rf'^{name}|{name}(?:\((\d+)\))?$'
|
||||
escaped_name = re.escape(memory.name)
|
||||
assert re.match(pattern, escaped_name), str(memory)
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, expected_message",
|
||||
[
|
||||
("", "Memory name cannot be empty or whitespace."),
|
||||
(" ", "Memory name cannot be empty or whitespace."),
|
||||
("a" * 129, f"Memory name '{'a'*129}' exceeds limit of 128."),
|
||||
],
|
||||
ids=["empty_name", "space_name", "too_long_name"],
|
||||
)
|
||||
def test_name_invalid(self, client, name, expected_message):
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.create_memory(**payload)
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
@pytest.mark.p2
|
||||
@given(name=valid_names())
|
||||
def test_type_invalid(self, client, name):
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["something"],
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.create_memory(**payload)
|
||||
assert str(exception_info.value) == f"Memory type '{ {'something'} }' is not supported.", str(exception_info.value)
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_name_duplicated(self, client):
|
||||
name = "duplicated_name_test"
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res1 = client.create_memory(**payload)
|
||||
assert res1.name == name, str(res1)
|
||||
|
||||
res2 = client.create_memory(**payload)
|
||||
assert res2.name == f"{name}(1)", str(res2)
|
||||
@ -0,0 +1,116 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.list_memory()
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
class TestCapability:
|
||||
@pytest.mark.p3
|
||||
def test_capability(self, client):
|
||||
count = 100
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(client.list_memory) for _ in range(count)]
|
||||
responses = list(as_completed(futures))
|
||||
assert len(responses) == count, responses
|
||||
assert all(future.result()["code"] == 0 for future in futures)
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_func")
|
||||
class TestMemoryList:
|
||||
@pytest.mark.p1
|
||||
def test_params_unset(self, client):
|
||||
res = client.list_memory()
|
||||
assert len(res["memory_list"]) == 3, str(res)
|
||||
assert res["total_count"] == 3, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_params_empty(self, client):
|
||||
res = client.list_memory(**{})
|
||||
assert len(res["memory_list"]) == 3, str(res)
|
||||
assert res["total_count"] == 3, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_page_size",
|
||||
[
|
||||
({"page": 1, "page_size": 10}, 3),
|
||||
({"page": 2, "page_size": 10}, 0),
|
||||
({"page": 1, "page_size": 2}, 2),
|
||||
({"page": 2, "page_size": 2}, 1),
|
||||
({"page": 5, "page_size": 10}, 0),
|
||||
],
|
||||
ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page" , "normal_middle_page",
|
||||
"full_data_single_page"],
|
||||
)
|
||||
def test_page(self, client, params, expected_page_size):
|
||||
# have added 3 memories in fixture
|
||||
res = client.list_memory(**params)
|
||||
assert len(res["memory_list"]) == expected_page_size, str(res)
|
||||
assert res["total_count"] == 3, str(res)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_memory_type(self, client):
|
||||
res = client.list_memory(**{"memory_type": ["semantic"]})
|
||||
for memory in res["memory_list"]:
|
||||
assert "semantic" in memory.memory_type, str(memory)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_multi_memory_type(self, client):
|
||||
res = client.list_memory(**{"memory_type": ["episodic", "procedural"]})
|
||||
for memory in res["memory_list"]:
|
||||
assert "episodic" in memory.memory_type or "procedural" in memory.memory_type, str(memory)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_storage_type(self, client):
|
||||
res = client.list_memory(**{"storage_type": "table"})
|
||||
for memory in res["memory_list"]:
|
||||
assert memory.storage_type == "table", str(memory)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_match_keyword(self, client):
|
||||
res = client.list_memory(**{"keywords": "s"})
|
||||
for memory in res["memory_list"]:
|
||||
assert "s" in memory.name, str(memory)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_get_config(self, client):
|
||||
memory_list = client.list_memory()
|
||||
assert len(memory_list["memory_list"]) > 0, str(memory_list)
|
||||
memory = memory_list["memory_list"][0]
|
||||
memory_id = memory.id
|
||||
memory_config = memory.get_config()
|
||||
assert memory_config.id == memory_id, memory_config
|
||||
for field in ["name", "avatar", "tenant_id", "owner_name", "memory_type", "storage_type",
|
||||
"embd_id", "llm_id", "permissions", "description", "memory_size", "forgetting_policy",
|
||||
"temperature", "system_prompt", "user_prompt"]:
|
||||
assert hasattr(memory, field), memory_config
|
||||
@ -0,0 +1,52 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.delete_memory("some_memory_id")
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_func")
|
||||
class TestMemoryDelete:
|
||||
@pytest.mark.p1
|
||||
def test_memory_id(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
client.delete_memory(memory_ids[0])
|
||||
res = client.list_memory()
|
||||
assert res["total_count"] == 2, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_id_wrong_uuid(self, client):
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.delete_memory("d94a8dc02c9711f0930f7fbc369eab6d")
|
||||
assert exception_info.value, str(exception_info.value)
|
||||
|
||||
res = client.list_memory()
|
||||
assert len(res["memory_list"]) == 2, res
|
||||
@ -0,0 +1,164 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
import pytest
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
from ragflow_sdk import RAGFlow, Memory
|
||||
from hypothesis import HealthCheck, example, given, settings
|
||||
from utils import encode_avatar
|
||||
from utils.file_utils import create_image_file
|
||||
from utils.hypothesis_utils import valid_names
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
ids=["empty_auth", "invalid_api_token"]
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
memory = Memory(client, {"id": "memory_id"})
|
||||
memory.update({"name": "New_Name"})
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_func")
|
||||
class TestMemoryUpdate:
|
||||
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("f" * 128)
|
||||
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
|
||||
def test_name(self, client, name):
|
||||
memory_ids = self.memory_ids
|
||||
update_dict = {"name": name}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.name == name, str(res)
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, expected_message",
|
||||
[
|
||||
("", "Memory name cannot be empty or whitespace."),
|
||||
(" ", "Memory name cannot be empty or whitespace."),
|
||||
("a" * 129, f"Memory name '{'a' * 129}' exceeds limit of 128."),
|
||||
]
|
||||
)
|
||||
def test_name_invalid(self, client, name, expected_message):
|
||||
memory_ids = self.memory_ids
|
||||
update_dict = {"name": name}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
memory.update(update_dict)
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_duplicate_name(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
update_dict = {"name": "Test_Memory"}
|
||||
memory_0 = Memory(client, {"id": memory_ids[0]})
|
||||
res_0 = memory_0.update(update_dict)
|
||||
assert res_0.name == "Test_Memory", str(res_0)
|
||||
|
||||
memory_1 = Memory(client, {"id": memory_ids[1]})
|
||||
res_1 = memory_1.update(update_dict)
|
||||
assert res_1.name == "Test_Memory(1)", str(res_1)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_avatar(self, client, tmp_path):
|
||||
memory_ids = self.memory_ids
|
||||
fn = create_image_file(tmp_path / "ragflow_test.png")
|
||||
update_dict = {"avatar": f"data:image/png;base64,{encode_avatar(fn)}"}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.avatar == f"data:image/png;base64,{encode_avatar(fn)}", str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_description(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
description = "This is a test description."
|
||||
update_dict = {"description": description}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.description == description, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_llm(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
llm_id = "glm-4@ZHIPU-AI"
|
||||
update_dict = {"llm_id": llm_id}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.llm_id == llm_id, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"permission",
|
||||
[
|
||||
"me",
|
||||
"team"
|
||||
],
|
||||
ids=["me", "team"]
|
||||
)
|
||||
def test_permission(self, client, permission):
|
||||
memory_ids = self.memory_ids
|
||||
update_dict = {"permissions": permission}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.permissions == permission.lower().strip(), str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_memory_size(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
memory_size = 1048576 # 1 MB
|
||||
update_dict = {"memory_size": memory_size}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.memory_size == memory_size, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_temperature(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
temperature = 0.7
|
||||
update_dict = {"temperature": temperature}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.temperature == temperature, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_system_prompt(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
system_prompt = "This is a system prompt."
|
||||
update_dict = {"system_prompt": system_prompt}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.system_prompt == system_prompt, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_user_prompt(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
user_prompt = "This is a user prompt."
|
||||
update_dict = {"user_prompt": user_prompt}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.user_prompt == user_prompt, res
|
||||
166
test/testcases/test_sdk_api/test_message_management/conftest.py
Normal file
166
test/testcases/test_sdk_api/test_message_management/conftest.py
Normal file
@ -0,0 +1,166 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import random
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def add_empty_raw_type_memory(client, request):
|
||||
def cleanup():
|
||||
memory_list_res = client.list_memory()
|
||||
exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]]
|
||||
for _memory_id in exist_memory_ids:
|
||||
client.delete_memory(_memory_id)
|
||||
request.addfinalizer(cleanup)
|
||||
payload = {
|
||||
"name": "test_memory_0",
|
||||
"memory_type": ["raw"],
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = client.create_memory(**payload)
|
||||
memory_id = res.id
|
||||
request.cls.memory_id = memory_id
|
||||
request.cls.memory_type = payload["memory_type"]
|
||||
return memory_id
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def add_empty_multiple_type_memory(client, request):
|
||||
def cleanup():
|
||||
memory_list_res = client.list_memory()
|
||||
exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]]
|
||||
for _memory_id in exist_memory_ids:
|
||||
client.delete_memory(_memory_id)
|
||||
request.addfinalizer(cleanup)
|
||||
payload = {
|
||||
"name": "test_memory_0",
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(1, 3)),
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = client.create_memory(**payload)
|
||||
memory_id = res.id
|
||||
request.cls.memory_id = memory_id
|
||||
request.cls.memory_type = payload["memory_type"]
|
||||
return memory_id
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def add_2_multiple_type_memory(client, request):
|
||||
def cleanup():
|
||||
memory_list_res = client.list_memory()
|
||||
exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]]
|
||||
for _memory_id in exist_memory_ids:
|
||||
client.delete_memory(_memory_id)
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
memory_ids = []
|
||||
for i in range(2):
|
||||
payload = {
|
||||
"name": f"test_memory_{i}",
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(1, 3)),
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = client.create_memory(**payload)
|
||||
memory_ids.append(res.id)
|
||||
request.cls.memory_ids = memory_ids
|
||||
return memory_ids
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def add_memory_with_multiple_type_message_func(client, request):
|
||||
def cleanup():
|
||||
memory_list_res = client.list_memory()
|
||||
exist_memory_ids = [mem.id for mem in memory_list_res["memory_list"]]
|
||||
for _memory_id in exist_memory_ids:
|
||||
client.delete_memory(_memory_id)
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
|
||||
payload = {
|
||||
"name": "test_memory_0",
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(1, 3)),
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
memory = client.create_memory(**payload)
|
||||
memory_id = memory.id
|
||||
agent_id = uuid.uuid4().hex
|
||||
message_payload = {
|
||||
"memory_id": [memory_id],
|
||||
"agent_id": agent_id,
|
||||
"session_id": uuid.uuid4().hex,
|
||||
"user_id": "",
|
||||
"user_input": "what is coriander?",
|
||||
"agent_response": """
|
||||
Coriander is a versatile herb with two main edible parts, and its name can refer to both:
|
||||
1. Leaves and Stems (often called Cilantro or Fresh Coriander): These are the fresh, green, fragrant leaves and tender stems of the plant Coriandrum sativum. They have a bright, citrusy, and sometimes pungent flavor. Cilantro is widely used as a garnish or key ingredient in cuisines like Mexican, Indian, Thai, and Middle Eastern.
|
||||
2. Seeds (called Coriander Seeds): These are the dried, golden-brown seeds of the same plant. When ground, they become coriander powder. The seeds have a warm, nutty, floral, and slightly citrusy taste, completely different from the fresh leaves. They are a fundamental spice in curries, stews, pickles, and baking.
|
||||
Key Point of Confusion: The naming differs by region. In North America, "coriander" typically refers to the seeds, while "cilantro" refers to the fresh leaves. In the UK, Europe, and many other parts of the world, "coriander" refers to the fresh herb, and the seeds are called "coriander seeds."
|
||||
"""
|
||||
}
|
||||
client.add_message(**message_payload)
|
||||
request.cls.memory_id = memory_id
|
||||
request.cls.agent_id = agent_id
|
||||
time.sleep(2) # make sure refresh to index before search
|
||||
return memory_id
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def add_memory_with_5_raw_message_func(client, request):
|
||||
def cleanup():
|
||||
memory_list_res = client.list_memory()
|
||||
exist_memory_ids = [mem.id for mem in memory_list_res["memory_list"]]
|
||||
for _memory_id in exist_memory_ids:
|
||||
client.delete_memory(_memory_id)
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
|
||||
payload = {
|
||||
"name": "test_memory_1",
|
||||
"memory_type": ["raw"],
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
memory = client.create_memory(**payload)
|
||||
memory_id = memory.id
|
||||
agent_ids = [uuid.uuid4().hex for _ in range(2)]
|
||||
session_ids = [uuid.uuid4().hex for _ in range(5)]
|
||||
for i in range(5):
|
||||
message_payload = {
|
||||
"memory_id": [memory_id],
|
||||
"agent_id": agent_ids[i % 2],
|
||||
"session_id": session_ids[i],
|
||||
"user_id": "",
|
||||
"user_input": "what is coriander?",
|
||||
"agent_response": """
|
||||
Coriander is a versatile herb with two main edible parts, and its name can refer to both:
|
||||
1. Leaves and Stems (often called Cilantro or Fresh Coriander): These are the fresh, green, fragrant leaves and tender stems of the plant Coriandrum sativum. They have a bright, citrusy, and sometimes pungent flavor. Cilantro is widely used as a garnish or key ingredient in cuisines like Mexican, Indian, Thai, and Middle Eastern.
|
||||
2. Seeds (called Coriander Seeds): These are the dried, golden-brown seeds of the same plant. When ground, they become coriander powder. The seeds have a warm, nutty, floral, and slightly citrusy taste, completely different from the fresh leaves. They are a fundamental spice in curries, stews, pickles, and baking.
|
||||
Key Point of Confusion: The naming differs by region. In North America, "coriander" typically refers to the seeds, while "cilantro" refers to the fresh leaves. In the UK, Europe, and many other parts of the world, "coriander" refers to the fresh herb, and the seeds are called "coriander seeds."
|
||||
"""
|
||||
}
|
||||
client.add_message(**message_payload)
|
||||
request.cls.memory_id = memory_id
|
||||
request.cls.agent_ids = agent_ids
|
||||
request.cls.session_ids = session_ids
|
||||
time.sleep(2) # make sure refresh to index before search
|
||||
return memory_id
|
||||
@ -0,0 +1,151 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
import uuid
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow, Memory
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.add_message(**{
|
||||
"memory_id": [""],
|
||||
"agent_id": "",
|
||||
"session_id": "",
|
||||
"user_id": "",
|
||||
"user_input": "what is pineapple?",
|
||||
"agent_response": ""
|
||||
})
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_empty_raw_type_memory")
|
||||
class TestAddRawMessage:
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_add_raw_message(self, client):
|
||||
memory_id = self.memory_id
|
||||
agent_id = uuid.uuid4().hex
|
||||
session_id = uuid.uuid4().hex
|
||||
message_payload = {
|
||||
"memory_id": [memory_id],
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id,
|
||||
"user_id": "",
|
||||
"user_input": "what is pineapple?",
|
||||
"agent_response": """
|
||||
A pineapple is a tropical fruit known for its sweet, tangy flavor and distinctive, spiky appearance. Here are the key facts:
|
||||
Scientific Name: Ananas comosus
|
||||
Physical Description: It has a tough, spiky, diamond-patterned outer skin (rind) that is usually green, yellow, or brownish. Inside, the juicy yellow flesh surrounds a fibrous core.
|
||||
Growth: Unlike most fruits, pineapples do not grow on trees. They grow from a central stem as a composite fruit, meaning they are formed from many individual berries that fuse together around the core. They grow on a short, leafy plant close to the ground.
|
||||
Uses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a popular ingredient in desserts, fruit salads, savory dishes (like pizzas or ham glazes), smoothies, and cocktails.
|
||||
Nutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat.
|
||||
Symbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures.
|
||||
Are you asking about the fruit itself, or its use in a specific context?
|
||||
"""
|
||||
}
|
||||
add_res = client.add_message(**message_payload)
|
||||
assert add_res == "All add to task.", str(add_res)
|
||||
time.sleep(2) # make sure refresh to index before search
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
message_res = memory.list_memory_messages(**{"agent_id": agent_id, "keywords": session_id})
|
||||
assert message_res["messages"]["total_count"] > 0
|
||||
for message in message_res["messages"]["message_list"]:
|
||||
assert message["agent_id"] == agent_id, message
|
||||
assert message["session_id"] == session_id, message
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_empty_multiple_type_memory")
|
||||
class TestAddMultipleTypeMessage:
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_add_multiple_type_message(self, client):
|
||||
memory_id = self.memory_id
|
||||
agent_id = uuid.uuid4().hex
|
||||
session_id = uuid.uuid4().hex
|
||||
message_payload = {
|
||||
"memory_id": [memory_id],
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id,
|
||||
"user_id": "",
|
||||
"user_input": "what is pineapple?",
|
||||
"agent_response": """
|
||||
A pineapple is a tropical fruit known for its sweet, tangy flavor and distinctive, spiky appearance. Here are the key facts:
|
||||
Scientific Name: Ananas comosus
|
||||
Physical Description: It has a tough, spiky, diamond-patterned outer skin (rind) that is usually green, yellow, or brownish. Inside, the juicy yellow flesh surrounds a fibrous core.
|
||||
Growth: Unlike most fruits, pineapples do not grow on trees. They grow from a central stem as a composite fruit, meaning they are formed from many individual berries that fuse together around the core. They grow on a short, leafy plant close to the ground.
|
||||
Uses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a popular ingredient in desserts, fruit salads, savory dishes (like pizzas or ham glazes), smoothies, and cocktails.
|
||||
Nutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat.
|
||||
Symbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures.
|
||||
Are you asking about the fruit itself, or its use in a specific context?
|
||||
"""
|
||||
}
|
||||
add_res = client.add_message(**message_payload)
|
||||
assert add_res == "All add to task.", str(add_res)
|
||||
time.sleep(2) # make sure refresh to index before search
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
message_res = memory.list_memory_messages(**{"agent_id": agent_id, "keywords": session_id})
|
||||
assert message_res["messages"]["total_count"] > 0
|
||||
for message in message_res["messages"]["message_list"]:
|
||||
assert message["agent_id"] == agent_id, message
|
||||
assert message["session_id"] == session_id, message
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_2_multiple_type_memory")
|
||||
class TestAddToMultipleMemory:
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_add_to_multiple_memory(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
agent_id = uuid.uuid4().hex
|
||||
session_id = uuid.uuid4().hex
|
||||
message_payload = {
|
||||
"memory_id": memory_ids,
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id,
|
||||
"user_id": "",
|
||||
"user_input": "what is pineapple?",
|
||||
"agent_response": """
|
||||
A pineapple is a tropical fruit known for its sweet, tangy flavor and distinctive, spiky appearance. Here are the key facts:
|
||||
Scientific Name: Ananas comosus
|
||||
Physical Description: It has a tough, spiky, diamond-patterned outer skin (rind) that is usually green, yellow, or brownish. Inside, the juicy yellow flesh surrounds a fibrous core.
|
||||
Growth: Unlike most fruits, pineapples do not grow on trees. They grow from a central stem as a composite fruit, meaning they are formed from many individual berries that fuse together around the core. They grow on a short, leafy plant close to the ground.
|
||||
Uses: Pineapples are eaten fresh, cooked, grilled, juiced, or canned. They are a popular ingredient in desserts, fruit salads, savory dishes (like pizzas or ham glazes), smoothies, and cocktails.
|
||||
Nutrition: They are a good source of Vitamin C, manganese, and contain an enzyme called bromelain, which aids in digestion and can tenderize meat.
|
||||
Symbolism: The pineapple is a traditional symbol of hospitality and welcome in many cultures.
|
||||
Are you asking about the fruit itself, or its use in a specific context?
|
||||
"""
|
||||
}
|
||||
add_res = client.add_message(**message_payload)
|
||||
assert add_res == "All add to task.", str(add_res)
|
||||
time.sleep(2) # make sure refresh to index before search
|
||||
for memory_id in memory_ids:
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
message_res = memory.list_memory_messages(**{"agent_id": agent_id, "keywords": session_id})
|
||||
assert message_res["messages"]["total_count"] > 0
|
||||
for message in message_res["messages"]["message_list"]:
|
||||
assert message["agent_id"] == agent_id, message
|
||||
assert message["session_id"] == session_id, message
|
||||
@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow, Memory
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
memory = Memory(client, {"id": "empty_memory_id"})
|
||||
memory.forget_message(0)
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_with_5_raw_message_func")
|
||||
class TestForgetMessage:
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_forget_message(self, client):
|
||||
memory_id = self.memory_id
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
list_res = memory.list_memory_messages()
|
||||
assert len(list_res["messages"]["message_list"]) > 0
|
||||
|
||||
message = random.choice(list_res["messages"]["message_list"])
|
||||
res = memory.forget_message(message["message_id"])
|
||||
assert res, str(res)
|
||||
|
||||
forgot_message_res = memory.get_message_content(message["message_id"])
|
||||
assert forgot_message_res["forget_at"] not in ["-", ""], forgot_message_res
|
||||
@ -0,0 +1,53 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow, Memory
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(INVALID_API_TOKEN, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
memory = Memory(client, {"id": "empty_memory_id"})
|
||||
memory.get_message_content(0)
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_with_multiple_type_message_func")
|
||||
class TestGetMessageContent:
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_get_message_content(self,client):
|
||||
memory_id = self.memory_id
|
||||
recent_messages = client.get_recent_messages([memory_id])
|
||||
assert len(recent_messages) > 0, recent_messages
|
||||
message = random.choice(recent_messages)
|
||||
message_id = message["message_id"]
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
content_res = memory.get_message_content(message_id)
|
||||
for field in ["content", "content_embed"]:
|
||||
assert field in content_res
|
||||
assert content_res[field] is not None, content_res
|
||||
@ -0,0 +1,64 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.get_recent_messages(["some_memory_id"])
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_with_5_raw_message_func")
|
||||
class TestGetRecentMessage:
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_get_recent_messages(self, client):
|
||||
memory_id = self.memory_id
|
||||
res = client.get_recent_messages([memory_id])
|
||||
assert len(res) == 5, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_recent_messages_by_agent(self, client):
|
||||
memory_id = self.memory_id
|
||||
agent_ids = self.agent_ids
|
||||
agent_id = random.choice(agent_ids)
|
||||
res = client.get_recent_messages(**{"agent_id": agent_id, "memory_id": [memory_id]})
|
||||
for message in res:
|
||||
assert message["agent_id"] == agent_id, message
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_recent_messages_by_session(self, client):
|
||||
memory_id = self.memory_id
|
||||
session_ids = self.session_ids
|
||||
session_id = random.choice(session_ids)
|
||||
res = client.get_recent_messages(**{"session_id": session_id, "memory_id": [memory_id]})
|
||||
for message in res:
|
||||
assert message["session_id"] == session_id, message
|
||||
@ -0,0 +1,101 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow, Memory
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
memory = Memory(client, {"id": "empty_memory_id"})
|
||||
memory.list_memory_messages()
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_with_5_raw_message_func")
|
||||
class TestMessageList:
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_params_unset(self, client):
|
||||
memory_id = self.memory_id
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
res = memory.list_memory_messages()
|
||||
assert len(res["messages"]["message_list"]) == 5, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_params_empty(self, client):
|
||||
memory_id = self.memory_id
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
res = memory.list_memory_messages(**{})
|
||||
assert len(res["messages"]["message_list"]) == 5, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_page_size",
|
||||
[
|
||||
({"page": 1, "page_size": 10}, 5),
|
||||
({"page": 2, "page_size": 10}, 0),
|
||||
({"page": 1, "page_size": 2}, 2),
|
||||
({"page": 3, "page_size": 2}, 1),
|
||||
({"page": 5, "page_size": 10}, 0),
|
||||
],
|
||||
ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page", "normal_middle_page",
|
||||
"full_data_single_page"],
|
||||
)
|
||||
def test_page_size(self, client, params, expected_page_size):
|
||||
# have added 5 messages in fixture
|
||||
memory_id = self.memory_id
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
res = memory.list_memory_messages(**params)
|
||||
assert len(res["messages"]["message_list"]) == expected_page_size, str(res)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_agent_id(self, client):
|
||||
memory_id = self.memory_id
|
||||
agent_ids = self.agent_ids
|
||||
agent_id = random.choice(agent_ids)
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
res = memory.list_memory_messages(**{"agent_id": agent_id})
|
||||
for message in res["messages"]["message_list"]:
|
||||
assert message["agent_id"] == agent_id, message
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Not support.")
|
||||
def test_search_keyword(self, client):
|
||||
memory_id = self.memory_id
|
||||
session_ids = self.session_ids
|
||||
session_id = random.choice(session_ids)
|
||||
slice_start = random.randint(0, len(session_id) - 2)
|
||||
slice_end = random.randint(slice_start + 1, len(session_id) - 1)
|
||||
keyword = session_id[slice_start:slice_end]
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
res = memory.list_memory_messages(**{"keywords": keyword})
|
||||
assert len(res["messages"]["message_list"]) > 0, res
|
||||
for message in res["messages"]["message_list"]:
|
||||
assert keyword in message["session_id"], message
|
||||
@ -0,0 +1,79 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow, Memory
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.search_message("", ["empty_memory_id"])
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_with_multiple_type_message_func")
|
||||
class TestSearchMessage:
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_query(self, client):
|
||||
memory_id = self.memory_id
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
list_res = memory.list_memory_messages()
|
||||
assert list_res["messages"]["total_count"] > 0
|
||||
|
||||
query = "Coriander is a versatile herb with two main edible parts. What's its name can refer to?"
|
||||
res = client.search_message(**{"memory_id": [memory_id], "query": query})
|
||||
assert len(res) > 0
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_query_with_agent_filter(self, client):
|
||||
memory_id = self.memory_id
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
list_res = memory.list_memory_messages()
|
||||
assert list_res["messages"]["total_count"] > 0
|
||||
|
||||
agent_id = self.agent_id
|
||||
query = "Coriander is a versatile herb with two main edible parts. What's its name can refer to?"
|
||||
res = client.search_message(**{"memory_id": [memory_id], "query": query, "agent_id": agent_id})
|
||||
assert len(res) > 0
|
||||
for message in res:
|
||||
assert message["agent_id"] == agent_id, message
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_query_with_not_default_params(self, client):
|
||||
memory_id = self.memory_id
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
list_res = memory.list_memory_messages()
|
||||
assert list_res["messages"]["total_count"] > 0
|
||||
|
||||
query = "Coriander is a versatile herb with two main edible parts. What's its name can refer to?"
|
||||
params = {
|
||||
"similarity_threshold": 0.1,
|
||||
"keywords_similarity_weight": 0.6,
|
||||
"top_n": 4
|
||||
}
|
||||
res = client.search_message(**{"memory_id": [memory_id], "query": query, **params})
|
||||
assert len(res) > 0
|
||||
assert len(res) <= params["top_n"]
|
||||
@ -0,0 +1,73 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow, Memory
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
memory = Memory(client, {"id": "empty_memory_id"})
|
||||
memory.update_message_status(0, False)
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_with_5_raw_message_func")
|
||||
class TestUpdateMessageStatus:
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_update_to_false(self, client):
|
||||
memory_id = self.memory_id
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
list_res = memory.list_memory_messages()
|
||||
assert len(list_res["messages"]["message_list"]) > 0, str(list_res)
|
||||
|
||||
message = random.choice(list_res["messages"]["message_list"])
|
||||
res = memory.update_message_status(message["message_id"], False)
|
||||
assert res, str(res)
|
||||
|
||||
updated_message_res = memory.get_message_content(message["message_id"])
|
||||
assert not updated_message_res["status"], str(updated_message_res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_update_to_true(self, client):
|
||||
memory_id = self.memory_id
|
||||
memory = Memory(client, {"id": memory_id})
|
||||
list_res = memory.list_memory_messages()
|
||||
assert len(list_res["messages"]["message_list"]) > 0, str(list_res)
|
||||
# set 1 random message to false first
|
||||
message = random.choice(list_res["messages"]["message_list"])
|
||||
set_to_false_res = memory.update_message_status(message["message_id"], False)
|
||||
assert set_to_false_res, str(set_to_false_res)
|
||||
updated_message_res = memory.get_message_content(message["message_id"])
|
||||
assert not updated_message_res["status"], updated_message_res
|
||||
# set to true
|
||||
set_to_true_res = memory.update_message_status(message["message_id"], True)
|
||||
assert set_to_true_res, str(set_to_true_res)
|
||||
res = memory.get_message_content(message["message_id"])
|
||||
assert res["status"], res
|
||||
@ -28,8 +28,8 @@ CHUNK_API_URL = f"/{VERSION}/chunk"
|
||||
DIALOG_APP_URL = f"/{VERSION}/dialog"
|
||||
# SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
|
||||
# SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"
|
||||
MEMORY_API_URL = f"/{VERSION}/memories"
|
||||
MESSAGE_API_URL = f"/{VERSION}/messages"
|
||||
MEMORY_API_URL = f"/api/{VERSION}/memories"
|
||||
MESSAGE_API_URL = f"/api/{VERSION}/messages"
|
||||
|
||||
|
||||
# KB APP
|
||||
|
||||
@ -21,7 +21,7 @@ from test_web_api.common import create_memory
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
from hypothesis import example, given, settings
|
||||
from test.testcases.utils.hypothesis_utils import valid_names
|
||||
from utils.hypothesis_utils import valid_names
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
|
||||
Reference in New Issue
Block a user