mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-05 09:47:47 +08:00
Feat/tenant model (#13072)
### What problem does this PR solve? Add id for table tenant_llm and apply in LLMBundle. ### Type of change - [x] Refactoring --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com> Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
@ -44,9 +44,10 @@ class _AwaitableValue:
|
||||
|
||||
|
||||
class _DummyKB:
|
||||
def __init__(self, embd_id="embd@factory", chunk_num=1):
|
||||
def __init__(self, embd_id="embd@factory", chunk_num=1, tenant_embd_id=1):
|
||||
self.embd_id = embd_id
|
||||
self.chunk_num = chunk_num
|
||||
self.tenant_embd_id = tenant_embd_id
|
||||
|
||||
def to_json(self):
|
||||
return {"id": "kb-1"}
|
||||
|
||||
@ -45,9 +45,10 @@ class _AwaitableValue:
|
||||
|
||||
|
||||
class _DummyKB:
|
||||
def __init__(self, tenant_id="tenant-1", embd_id="embd-1"):
|
||||
def __init__(self, tenant_id="tenant-1", embd_id="embd-1", tenant_embd_id=1):
|
||||
self.tenant_id = tenant_id
|
||||
self.embd_id = embd_id
|
||||
self.tenant_embd_id = tenant_embd_id
|
||||
|
||||
|
||||
class _DummyRetriever:
|
||||
@ -102,6 +103,138 @@ def _load_dify_retrieval_module(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
|
||||
monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))
|
||||
|
||||
# Mock tenant_llm_service for TenantLLMService and TenantService
|
||||
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
|
||||
|
||||
class _MockModelConfig:
|
||||
def __init__(self, tenant_id, model_name):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_name = model_name
|
||||
self.llm_factory = "Builtin"
|
||||
self.api_key = "fake-api-key"
|
||||
self.api_base = "https://api.example.com"
|
||||
self.model_type = "chat"
|
||||
self.max_tokens = 8192
|
||||
self.used_tokens = 0
|
||||
self.status = 1
|
||||
self.id = 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tenant_id": self.tenant_id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"api_key": self.api_key,
|
||||
"api_base": self.api_base,
|
||||
"model_type": self.model_type,
|
||||
"max_tokens": self.max_tokens,
|
||||
"used_tokens": self.used_tokens,
|
||||
"status": self.status,
|
||||
"id": self.id
|
||||
}
|
||||
|
||||
class _StubTenantService:
|
||||
@staticmethod
|
||||
def get_by_id(tenant_id):
|
||||
# Return a mock tenant with default model configurations
|
||||
return True, SimpleNamespace(
|
||||
id=tenant_id,
|
||||
llm_id="chat-model",
|
||||
embd_id="embd-model",
|
||||
asr_id="asr-model",
|
||||
img2txt_id="img2txt-model",
|
||||
rerank_id="rerank-model",
|
||||
tts_id="tts-model"
|
||||
)
|
||||
|
||||
class _StubTenantLLMService:
|
||||
@staticmethod
|
||||
def get_api_key(tenant_id, model_name):
|
||||
return _MockModelConfig(tenant_id, model_name)
|
||||
|
||||
@staticmethod
|
||||
def split_model_name_and_factory(model_name):
|
||||
if "@" in model_name:
|
||||
parts = model_name.split("@")
|
||||
return parts[0], parts[1]
|
||||
return model_name, None
|
||||
|
||||
tenant_llm_service_mod.TenantService = _StubTenantService
|
||||
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
|
||||
|
||||
# Mock llm_service for LLMService
|
||||
llm_service_mod = ModuleType("api.db.services.llm_service")
|
||||
|
||||
class _StubLLM:
|
||||
def __init__(self, llm_name):
|
||||
self.llm_name = llm_name
|
||||
self.is_tools = False
|
||||
|
||||
class _StubLLMBundle:
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
self.tenant_id = tenant_id
|
||||
self.model_config = model_config
|
||||
self.lang = lang
|
||||
|
||||
def encode(self, texts: list):
|
||||
import numpy as np
|
||||
# Return mock embeddings and token usage
|
||||
return [np.array([0.1, 0.2, 0.3]) for _ in texts], len(texts) * 10
|
||||
|
||||
llm_service_mod.LLMService = SimpleNamespace(
|
||||
query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []
|
||||
)
|
||||
llm_service_mod.LLMBundle = _StubLLMBundle
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
|
||||
|
||||
# Mock tenant_model_service to ensure it uses mocked services
|
||||
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
|
||||
|
||||
class _MockModelConfig2:
|
||||
def __init__(self, tenant_id, model_name):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_name = model_name
|
||||
self.llm_factory = "Builtin"
|
||||
self.api_key = "fake-api-key"
|
||||
self.api_base = "https://api.example.com"
|
||||
self.model_type = "chat"
|
||||
self.max_tokens = 8192
|
||||
self.used_tokens = 0
|
||||
self.status = 1
|
||||
self.id = 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tenant_id": self.tenant_id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"api_key": self.api_key,
|
||||
"api_base": self.api_base,
|
||||
"model_type": self.model_type,
|
||||
"max_tokens": self.max_tokens,
|
||||
"used_tokens": self.used_tokens,
|
||||
"status": self.status,
|
||||
"id": self.id
|
||||
}
|
||||
|
||||
def _get_model_config_by_id(tenant_model_id: int) -> dict:
|
||||
return _MockModelConfig2("tenant-1", "model-1").to_dict()
|
||||
|
||||
def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
|
||||
if not model_name:
|
||||
raise Exception("Model Name is required")
|
||||
return _MockModelConfig2(tenant_id, model_name).to_dict()
|
||||
|
||||
def _get_tenant_default_model_by_type(tenant_id: str, model_type):
|
||||
# Return mock tenant with default model configurations
|
||||
return _MockModelConfig2(tenant_id, "chat-model").to_dict()
|
||||
|
||||
tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id
|
||||
tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name
|
||||
tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type
|
||||
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
|
||||
|
||||
module_name = "test_dify_retrieval_routes_unit_module"
|
||||
module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
@ -134,7 +267,6 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch):
|
||||
monkeypatch.setattr(module, "jsonify", lambda payload: payload)
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_meta_by_kbs", lambda _kb_ids: [{"doc_id": "doc-1"}])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: object())
|
||||
monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", []))
|
||||
monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: [])
|
||||
|
||||
@ -185,7 +317,6 @@ def test_retrieval_not_found_exception_mapping(monkeypatch):
|
||||
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_meta_by_kbs", lambda _kb_ids: [])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: object())
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
|
||||
|
||||
class _BrokenRetriever:
|
||||
@ -205,7 +336,6 @@ def test_retrieval_generic_exception_mapping(monkeypatch):
|
||||
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_meta_by_kbs", lambda _kb_ids: [])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: object())
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
|
||||
|
||||
class _BrokenRetriever:
|
||||
|
||||
@ -151,6 +151,149 @@ def _load_doc_module(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
|
||||
monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))
|
||||
|
||||
# Mock tenant_llm_service for TenantLLMService and TenantService
|
||||
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
|
||||
|
||||
class _MockModelConfig:
|
||||
def __init__(self, tenant_id, model_name):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_name = model_name
|
||||
self.llm_factory = "Builtin"
|
||||
self.api_key = "fake-api-key"
|
||||
self.api_base = "https://api.example.com"
|
||||
self.model_type = "embedding"
|
||||
self.max_tokens = 8192
|
||||
self.used_tokens = 0
|
||||
self.status = 1
|
||||
self.id = 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tenant_id": self.tenant_id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"api_key": self.api_key,
|
||||
"api_base": self.api_base,
|
||||
"model_type": self.model_type,
|
||||
"max_tokens": self.max_tokens,
|
||||
"used_tokens": self.used_tokens,
|
||||
"status": self.status,
|
||||
"id": self.id
|
||||
}
|
||||
|
||||
class _StubTenantService:
|
||||
@staticmethod
|
||||
def get_by_id(tenant_id):
|
||||
return True, SimpleNamespace(
|
||||
id=tenant_id,
|
||||
llm_id="chat-model",
|
||||
embd_id="embd-model",
|
||||
asr_id="asr-model",
|
||||
img2txt_id="img2txt-model",
|
||||
rerank_id="rerank-model",
|
||||
tts_id="tts-model"
|
||||
)
|
||||
|
||||
class _StubTenantLLMService:
|
||||
@staticmethod
|
||||
def get_api_key(tenant_id, model_name):
|
||||
return _MockModelConfig(tenant_id, model_name)
|
||||
|
||||
@staticmethod
|
||||
def split_model_name_and_factory(model_name):
|
||||
if "@" in model_name:
|
||||
parts = model_name.split("@")
|
||||
return parts[0], parts[1]
|
||||
return model_name, None
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(tenant_model_id):
|
||||
return True, _MockModelConfig("tenant-1", "model-1")
|
||||
|
||||
@staticmethod
|
||||
def model_instance(model_config):
|
||||
class _EmbedModel:
|
||||
def encode(self, texts):
|
||||
import numpy as np
|
||||
return [np.array([0.2, 0.8]), np.array([0.3, 0.7])], 1
|
||||
return _EmbedModel()
|
||||
|
||||
tenant_llm_service_mod.TenantService = _StubTenantService
|
||||
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
|
||||
|
||||
# Mock LLMService
|
||||
llm_service_mod = ModuleType("api.db.services.llm_service")
|
||||
|
||||
class _StubLLM:
|
||||
def __init__(self, llm_name):
|
||||
self.llm_name = llm_name
|
||||
self.is_tools = False
|
||||
|
||||
class _StubLLMBundle:
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
self.tenant_id = tenant_id
|
||||
self.model_config = model_config
|
||||
self.lang = lang
|
||||
|
||||
def encode(self, texts: list):
|
||||
import numpy as np
|
||||
# Return mock embeddings and token usage
|
||||
return [np.array([0.2, 0.8]), np.array([0.3, 0.7])], len(texts) * 10
|
||||
|
||||
llm_service_mod.LLMService = SimpleNamespace(
|
||||
query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []
|
||||
)
|
||||
llm_service_mod.LLMBundle = _StubLLMBundle
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
|
||||
|
||||
# Mock tenant_model_service to ensure it uses mocked services
|
||||
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
|
||||
|
||||
class _MockModelConfig2:
|
||||
def __init__(self, tenant_id, model_name):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_name = model_name
|
||||
self.llm_factory = "Builtin"
|
||||
self.api_key = "fake-api-key"
|
||||
self.api_base = "https://api.example.com"
|
||||
self.model_type = "embedding"
|
||||
self.max_tokens = 8192
|
||||
self.used_tokens = 0
|
||||
self.status = 1
|
||||
self.id = 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tenant_id": self.tenant_id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"api_key": self.api_key,
|
||||
"api_base": self.api_base,
|
||||
"model_type": self.model_type,
|
||||
"max_tokens": self.max_tokens,
|
||||
"used_tokens": self.used_tokens,
|
||||
"status": self.status,
|
||||
"id": self.id
|
||||
}
|
||||
|
||||
def _get_model_config_by_id(tenant_model_id: int) -> dict:
|
||||
return _MockModelConfig2("tenant-1", "model-1").to_dict()
|
||||
|
||||
def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
|
||||
if not model_name:
|
||||
raise Exception("Model Name is required")
|
||||
return _MockModelConfig2(tenant_id, model_name).to_dict()
|
||||
|
||||
def _get_tenant_default_model_by_type(tenant_id: str, model_type):
|
||||
# Return mock tenant with default model configurations
|
||||
return _MockModelConfig2(tenant_id, "chat-model").to_dict()
|
||||
|
||||
tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id
|
||||
tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name
|
||||
tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type
|
||||
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
|
||||
|
||||
module_path = repo_root / "api" / "apps" / "sdk" / "doc.py"
|
||||
spec = importlib.util.spec_from_file_location("test_doc_sdk_routes_unit", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
@ -762,6 +905,7 @@ class TestDocRoutesUnit:
|
||||
monkeypatch.setattr(module.rag_tokenizer, "fine_grained_tokenize", lambda text: text or "")
|
||||
monkeypatch.setattr(module.rag_tokenizer, "is_chinese", lambda _text: False)
|
||||
monkeypatch.setattr(module.DocumentService, "get_embd_id", lambda _doc_id: "embd")
|
||||
monkeypatch.setattr(module.DocumentService, "get_tenant_embd_id", lambda _doc_id: 1)
|
||||
|
||||
class _EmbedModel:
|
||||
def encode(self, _texts):
|
||||
@ -851,8 +995,8 @@ class TestDocRoutesUnit:
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({"dataset_ids": ["ds-1"], "question": "q", "highlight": "True"}),
|
||||
)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="m1", tenant_id="tenant-1")])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, SimpleNamespace(tenant_id="tenant-1", embd_id="m1")))
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="m1", tenant_id="tenant-1", tenant_embd_id=1)])
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, SimpleNamespace(tenant_id="tenant-1", embd_id="m1", tenant_embd_id=1)))
|
||||
|
||||
class _Retriever:
|
||||
async def retrieval(self, *_args, **_kwargs):
|
||||
@ -865,7 +1009,7 @@ class TestDocRoutesUnit:
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: {})
|
||||
monkeypatch.setattr(module.settings, "retriever", _Retriever())
|
||||
res = _run(module.retrieval_test.__wrapped__("tenant-1"))
|
||||
assert res["code"] == 0
|
||||
assert res["code"] == 0, res["message"]
|
||||
assert res["data"]["chunks"] == []
|
||||
|
||||
monkeypatch.setattr(
|
||||
@ -974,7 +1118,7 @@ class TestDocRoutesUnit:
|
||||
}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, SimpleNamespace(tenant_id="tenant-1", embd_id="m1")))
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, SimpleNamespace(tenant_id="tenant-1", embd_id="m1", tenant_embd_id=1)))
|
||||
monkeypatch.setattr(module, "cross_languages", _cross_languages)
|
||||
monkeypatch.setattr(module, "keyword_extraction", _keyword_extraction)
|
||||
monkeypatch.setattr(module.settings, "retriever", _FeatureRetriever())
|
||||
@ -982,7 +1126,7 @@ class TestDocRoutesUnit:
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: {})
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: SimpleNamespace())
|
||||
res = _run(module.retrieval_test.__wrapped__("tenant-1"))
|
||||
assert res["code"] == 0
|
||||
assert res["code"] == 0, res["message"]
|
||||
assert feature_calls["cross"] == ("fr",)
|
||||
assert feature_calls["keyword"] == "q-xl"
|
||||
assert feature_calls["retrieval_question"] == "q-xl-kw"
|
||||
|
||||
@ -121,6 +121,132 @@ def _load_session_module(monkeypatch):
|
||||
common_pkg.__path__ = [str(repo_root / "common")]
|
||||
monkeypatch.setitem(sys.modules, "common", common_pkg)
|
||||
|
||||
# Mock common.constants module
|
||||
from enum import Enum
|
||||
from strenum import StrEnum
|
||||
|
||||
class _StubLLMType(StrEnum):
|
||||
CHAT = "chat"
|
||||
EMBEDDING = "embedding"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
IMAGE2TEXT = "image2text"
|
||||
RERANK = "rerank"
|
||||
TTS = "tts"
|
||||
OCR = "ocr"
|
||||
|
||||
class _StubParserType(StrEnum):
|
||||
PRESENTATION = "presentation"
|
||||
LAWS = "laws"
|
||||
MANUAL = "manual"
|
||||
PAPER = "paper"
|
||||
RESUME = "resume"
|
||||
BOOK = "book"
|
||||
QA = "qa"
|
||||
TABLE = "table"
|
||||
NAIVE = "naive"
|
||||
PICTURE = "picture"
|
||||
ONE = "one"
|
||||
AUDIO = "audio"
|
||||
EMAIL = "email"
|
||||
KG = "knowledge_graph"
|
||||
TAG = "tag"
|
||||
|
||||
class _StubRetCode(int, Enum):
|
||||
SUCCESS = 0
|
||||
NOT_EFFECTIVE = 10
|
||||
EXCEPTION_ERROR = 100
|
||||
ARGUMENT_ERROR = 101
|
||||
DATA_ERROR = 102
|
||||
OPERATING_ERROR = 103
|
||||
CONNECTION_ERROR = 105
|
||||
RUNNING = 106
|
||||
PERMISSION_ERROR = 108
|
||||
AUTHENTICATION_ERROR = 109
|
||||
BAD_REQUEST = 400
|
||||
UNAUTHORIZED = 401
|
||||
SERVER_ERROR = 500
|
||||
FORBIDDEN = 403
|
||||
NOT_FOUND = 404
|
||||
CONFLICT = 409
|
||||
|
||||
class _StubStatusEnum(str, Enum):
|
||||
VALID = "1"
|
||||
INVALID = "0"
|
||||
|
||||
class _StubActiveEnum(Enum):
|
||||
ACTIVE = "1"
|
||||
INACTIVE = "0"
|
||||
|
||||
class _StubStorage(Enum):
|
||||
MINIO = 1
|
||||
AZURE_SPN = 2
|
||||
AZURE_SAS = 3
|
||||
AWS_S3 = 4
|
||||
OSS = 5
|
||||
OPENDAL = 6
|
||||
GCS = 7
|
||||
|
||||
class _StubMCPServerType(StrEnum):
|
||||
SSE = "sse"
|
||||
STREAMABLE_HTTP = "streamable-http"
|
||||
|
||||
class _StubTaskStatus(StrEnum):
|
||||
UNSTART = "0"
|
||||
RUNNING = "1"
|
||||
CANCEL = "2"
|
||||
DONE = "3"
|
||||
FAIL = "4"
|
||||
SCHEDULE = "5"
|
||||
|
||||
class _StubFileSource(StrEnum):
|
||||
LOCAL = ""
|
||||
KNOWLEDGEBASE = "knowledgebase"
|
||||
S3 = "s3"
|
||||
NOTION = "notion"
|
||||
DISCORD = "discord"
|
||||
CONFLUENCE = "confluence"
|
||||
GMAIL = "gmail"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
JIRA = "jira"
|
||||
SHAREPOINT = "sharepoint"
|
||||
SLACK = "slack"
|
||||
TEAMS = "teams"
|
||||
WEBDAV = "webdav"
|
||||
MOODLE = "moodle"
|
||||
DROPBOX = "dropbox"
|
||||
BOX = "box"
|
||||
R2 = "r2"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
AIRTABLE = "airtable"
|
||||
ASANA = "asana"
|
||||
GITHUB = "github"
|
||||
GITLAB = "gitlab"
|
||||
IMAP = "imap"
|
||||
BITBUCKET = "bitbucket"
|
||||
ZENDESK = "zendesk"
|
||||
SEAFILE = "seafile"
|
||||
MYSQL = "mysql"
|
||||
POSTGRESQL = "postgresql"
|
||||
|
||||
common_constants_mod = ModuleType("common.constants")
|
||||
common_constants_mod.LLMType = _StubLLMType
|
||||
common_constants_mod.ParserType = _StubParserType
|
||||
common_constants_mod.RetCode = _StubRetCode
|
||||
common_constants_mod.StatusEnum = _StubStatusEnum
|
||||
common_constants_mod.ActiveEnum = _StubActiveEnum
|
||||
common_constants_mod.Storage = _StubStorage
|
||||
common_constants_mod.MCPServerType = _StubMCPServerType
|
||||
common_constants_mod.TaskStatus = _StubTaskStatus
|
||||
common_constants_mod.FileSource = _StubFileSource
|
||||
common_constants_mod.SERVICE_CONF = "service_conf.yaml"
|
||||
common_constants_mod.RAG_FLOW_SERVICE_NAME = "ragflow"
|
||||
common_constants_mod.SVR_QUEUE_NAME = "rag_flow_svr_queue"
|
||||
common_constants_mod.SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
|
||||
common_constants_mod.PAGERANK_FLD = "pagerank_fea"
|
||||
common_constants_mod.TAG_FLD = "tag_feas"
|
||||
monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod)
|
||||
|
||||
deepdoc_pkg = ModuleType("deepdoc")
|
||||
deepdoc_parser_pkg = ModuleType("deepdoc.parser")
|
||||
deepdoc_parser_pkg.__path__ = []
|
||||
@ -166,6 +292,180 @@ def _load_session_module(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
|
||||
monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))
|
||||
|
||||
# Mock tenant_llm_service for TenantLLMService and TenantService
|
||||
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
|
||||
|
||||
class _MockModelConfig:
|
||||
def __init__(self, tenant_id, model_name):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_name = model_name
|
||||
self.llm_factory = "Builtin"
|
||||
self.api_key = "fake-api-key"
|
||||
self.api_base = "https://api.example.com"
|
||||
self.model_type = "chat"
|
||||
self.max_tokens = 8192
|
||||
self.used_tokens = 0
|
||||
self.status = 1
|
||||
self.id = 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tenant_id": self.tenant_id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"api_key": self.api_key,
|
||||
"api_base": self.api_base,
|
||||
"model_type": self.model_type,
|
||||
"max_tokens": self.max_tokens,
|
||||
"used_tokens": self.used_tokens,
|
||||
"status": self.status,
|
||||
"id": self.id
|
||||
}
|
||||
|
||||
class _StubTenantService:
|
||||
@staticmethod
|
||||
def get_by_id(tenant_id):
|
||||
# Return a mock tenant with default model configurations
|
||||
return True, SimpleNamespace(
|
||||
id=tenant_id,
|
||||
llm_id="chat-model",
|
||||
embd_id="embd-model",
|
||||
asr_id="asr-model",
|
||||
img2txt_id="img2txt-model",
|
||||
rerank_id="rerank-model",
|
||||
tts_id="tts-model"
|
||||
)
|
||||
|
||||
class _StubTenantLLMService:
|
||||
@staticmethod
|
||||
def get_api_key(tenant_id, model_name):
|
||||
return _MockModelConfig(tenant_id, model_name)
|
||||
|
||||
@staticmethod
|
||||
def split_model_name_and_factory(model_name):
|
||||
if "@" in model_name:
|
||||
parts = model_name.split("@")
|
||||
return parts[0], parts[1]
|
||||
return model_name, None
|
||||
|
||||
class _StubLLMFactoriesService:
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
|
||||
tenant_llm_service_mod.TenantService = _StubTenantService
|
||||
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
|
||||
tenant_llm_service_mod.LLMFactoriesService = _StubLLMFactoriesService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
|
||||
|
||||
# Mock LLMService
|
||||
llm_service_mod = ModuleType("api.db.services.llm_service")
|
||||
|
||||
class _StubLLM:
|
||||
def __init__(self, llm_name):
|
||||
self.llm_name = llm_name
|
||||
self.is_tools = False
|
||||
|
||||
llm_service_mod.LLMService = SimpleNamespace(
|
||||
query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []
|
||||
)
|
||||
|
||||
class _StubLLMBundle:
|
||||
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
|
||||
self.tenant_id = tenant_id
|
||||
self.model_config = model_config
|
||||
self.lang = lang
|
||||
|
||||
async def async_chat(self, prompt, messages, options):
|
||||
return "mock response"
|
||||
|
||||
def transcription(self, audio_path):
|
||||
return "mock transcription"
|
||||
|
||||
llm_service_mod.LLMBundle = _StubLLMBundle
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
|
||||
|
||||
# Mock tenant_model_service to ensure it uses mocked services
|
||||
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
|
||||
|
||||
class _MockModelConfig2:
|
||||
def __init__(self, tenant_id, model_name, model_type="chat"):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_name = model_name
|
||||
self.llm_factory = "Builtin"
|
||||
self.api_key = "fake-api-key"
|
||||
self.api_base = "https://api.example.com"
|
||||
self.model_type = model_type
|
||||
self.max_tokens = 8192
|
||||
self.used_tokens = 0
|
||||
self.status = 1
|
||||
self.id = 1
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tenant_id": self.tenant_id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"api_key": self.api_key,
|
||||
"api_base": self.api_base,
|
||||
"model_type": self.model_type,
|
||||
"max_tokens": self.max_tokens,
|
||||
"used_tokens": self.used_tokens,
|
||||
"status": self.status,
|
||||
"id": self.id
|
||||
}
|
||||
|
||||
def _get_model_config_by_id(tenant_model_id: int) -> dict:
|
||||
return _MockModelConfig2("tenant-1", "model-1").to_dict()
|
||||
|
||||
def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
|
||||
if not model_name:
|
||||
raise Exception("Model Name is required")
|
||||
return _MockModelConfig2(tenant_id, model_name, model_type).to_dict()
|
||||
|
||||
def _get_tenant_default_model_by_type(tenant_id: str, model_type):
|
||||
# Check if tenant exists
|
||||
from api.db.services.tenant_llm_service import TenantService
|
||||
exist, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not exist:
|
||||
raise LookupError("Tenant not found!")
|
||||
# Return mock tenant with default model configurations
|
||||
model_type_val = model_type if isinstance(model_type, str) else model_type.value
|
||||
model_name = ""
|
||||
if model_type_val == "embedding":
|
||||
model_name = tenant.embd_id
|
||||
elif model_type_val == "speech2text":
|
||||
model_name = tenant.asr_id
|
||||
elif model_type_val == "image2text":
|
||||
model_name = tenant.img2txt_id
|
||||
elif model_type_val == "chat":
|
||||
model_name = tenant.llm_id
|
||||
elif model_type_val == "rerank":
|
||||
model_name = tenant.rerank_id
|
||||
elif model_type_val == "tts":
|
||||
model_name = tenant.tts_id
|
||||
elif model_type_val == "ocr":
|
||||
raise Exception("OCR model name is required")
|
||||
if not model_name:
|
||||
# Use friendly model type names
|
||||
friendly_names = {
|
||||
"embedding": "Embedding",
|
||||
"speech2text": "ASR",
|
||||
"image2text": "Image2Text",
|
||||
"chat": "Chat",
|
||||
"rerank": "Rerank",
|
||||
"tts": "TTS",
|
||||
"ocr": "OCR"
|
||||
}
|
||||
friendly_name = friendly_names.get(model_type_val, model_type_val)
|
||||
raise Exception(f"No default {friendly_name} model is set")
|
||||
return _MockModelConfig2(tenant_id, model_name, model_type_val).to_dict()
|
||||
|
||||
tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id
|
||||
tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name
|
||||
tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type
|
||||
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
|
||||
|
||||
agent_pkg = ModuleType("agent")
|
||||
agent_pkg.__path__ = []
|
||||
agent_canvas_mod = ModuleType("agent.canvas")
|
||||
@ -200,6 +500,29 @@ def _load_session_module(monkeypatch):
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, "test_session_sdk_routes_unit_module", module)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Add TenantService to module for test compatibility
|
||||
class _StubTenantServiceForTest:
|
||||
@staticmethod
|
||||
def get_info_by(tenant_id):
|
||||
# Return mock tenant info for tests
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(tenant_id):
|
||||
# Return mock tenant by id
|
||||
return True, SimpleNamespace(
|
||||
id=tenant_id,
|
||||
llm_id="chat-model",
|
||||
embd_id="embd-model",
|
||||
asr_id="asr-model",
|
||||
img2txt_id="img2txt-model",
|
||||
rerank_id="rerank-model",
|
||||
tts_id="tts-model"
|
||||
)
|
||||
|
||||
module.TenantService = _StubTenantServiceForTest
|
||||
|
||||
return module
|
||||
|
||||
|
||||
@ -1028,9 +1351,12 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
|
||||
|
||||
llm_calls = []
|
||||
|
||||
def _fake_llm_bundle(tenant_id, llm_type, *args, **kwargs):
|
||||
llm_calls.append((tenant_id, llm_type, args, kwargs))
|
||||
return SimpleNamespace(tenant_id=tenant_id, llm_type=llm_type, args=args, kwargs=kwargs)
|
||||
def _fake_llm_bundle(tenant_id, model_config, *args, **kwargs):
|
||||
# Extract llm_type from model_config for comparison
|
||||
llm_type = model_config.get("model_type") if isinstance(model_config, dict) else model_config
|
||||
llm_name = model_config.get("llm_name") if isinstance(model_config, dict) else None
|
||||
llm_calls.append((tenant_id, llm_type, llm_name, args, kwargs))
|
||||
return SimpleNamespace(tenant_id=tenant_id, llm_type=llm_type, llm_name=llm_name, args=args, kwargs=kwargs)
|
||||
|
||||
monkeypatch.setattr(module, "LLMBundle", _fake_llm_bundle)
|
||||
monkeypatch.setattr(
|
||||
@ -1128,7 +1454,7 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
module.KnowledgebaseService,
|
||||
"get_by_id",
|
||||
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model")),
|
||||
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model", tenant_embd_id=None)),
|
||||
)
|
||||
res = _run(handler())
|
||||
assert res["code"] == 0
|
||||
@ -1143,7 +1469,7 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
|
||||
assert retrieval_capture["local_doc_ids"] == ["doc-filtered"]
|
||||
assert retrieval_capture["rank_feature"] == ["label-1"]
|
||||
assert retrieval_capture["rerank_mdl"] is not None
|
||||
assert any(call[1] == module.LLMType.EMBEDDING.value and call[3].get("llm_name") == "embd-model" for call in llm_calls)
|
||||
assert any(call[1] == module.LLMType.EMBEDDING.value and call[2] == "embd-model" for call in llm_calls)
|
||||
|
||||
llm_calls.clear()
|
||||
|
||||
@ -1178,7 +1504,7 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
module.KnowledgebaseService,
|
||||
"get_by_id",
|
||||
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model")),
|
||||
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model", tenant_embd_id=None)),
|
||||
)
|
||||
res = _run(handler())
|
||||
assert res["code"] == 0
|
||||
@ -1247,7 +1573,11 @@ def test_searchbots_related_questions_embedded_matrix_unit(monkeypatch):
|
||||
res = _run(handler())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] == ["Alpha", "Beta"]
|
||||
assert captured["bundle_args"] == ("tenant-1", module.LLMType.CHAT, "chat-x")
|
||||
# LLMBundle is called with (tenant_id, model_config)
|
||||
# model_config is a dict with model_type, llm_name, etc.
|
||||
assert captured["bundle_args"][0] == "tenant-1"
|
||||
assert captured["bundle_args"][1].get("model_type") == module.LLMType.CHAT
|
||||
assert captured["bundle_args"][1].get("llm_name") == "chat-x"
|
||||
assert captured["options"] == {"temperature": 0.2}
|
||||
assert "Keywords: solar" in captured["messages"][0]["content"]
|
||||
|
||||
@ -1361,12 +1691,14 @@ def test_sequence2txt_embedded_validation_and_stream_matrix_unit(monkeypatch):
|
||||
assert "Unsupported audio format: .txt" in res["message"]
|
||||
|
||||
_set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [])
|
||||
tenant_llm_service = sys.modules["api.db.services.tenant_llm_service"]
|
||||
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (False, None))
|
||||
res = _run(handler("tenant-1"))
|
||||
assert res["message"] == "Tenant not found!"
|
||||
|
||||
_set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "asr_id": ""}])
|
||||
tenant_llm_service = sys.modules["api.db.services.tenant_llm_service"]
|
||||
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="", tts_id="", llm_id="", embd_id="", img2txt_id="", rerank_id="")))
|
||||
res = _run(handler("tenant-1"))
|
||||
assert res["message"] == "No default ASR model is set"
|
||||
|
||||
@ -1378,7 +1710,7 @@ def test_sequence2txt_embedded_validation_and_stream_matrix_unit(monkeypatch):
|
||||
return []
|
||||
|
||||
_set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "asr_id": "asr-x"}])
|
||||
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="asr-x", tts_id="", llm_id="", embd_id="", img2txt_id="", rerank_id="")))
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncASR())
|
||||
monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup fail")))
|
||||
res = _run(handler("tenant-1"))
|
||||
@ -1420,14 +1752,15 @@ def test_sequence2txt_embedded_validation_and_stream_matrix_unit(monkeypatch):
|
||||
def test_tts_embedded_stream_and_error_matrix_unit(monkeypatch):
|
||||
module = _load_session_module(monkeypatch)
|
||||
handler = inspect.unwrap(module.tts)
|
||||
monkeypatch.setattr(module, "Response", _StubResponse)
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"text": "A。B"}))
|
||||
monkeypatch.setattr(module, "Response", _StubResponse)
|
||||
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [])
|
||||
tenant_llm_service = sys.modules["api.db.services.tenant_llm_service"]
|
||||
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (False, None))
|
||||
res = _run(handler("tenant-1"))
|
||||
assert res["message"] == "Tenant not found!"
|
||||
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "tts_id": ""}])
|
||||
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="", tts_id="", llm_id="", embd_id="", img2txt_id="", rerank_id="")))
|
||||
res = _run(handler("tenant-1"))
|
||||
assert res["message"] == "No default TTS model is set"
|
||||
|
||||
@ -1437,7 +1770,7 @@ def test_tts_embedded_stream_and_error_matrix_unit(monkeypatch):
|
||||
return []
|
||||
yield f"chunk-{txt}".encode("utf-8")
|
||||
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "tts_id": "tts-x"}])
|
||||
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="", tts_id="tts-x", llm_id="", embd_id="", img2txt_id="", rerank_id="")))
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
|
||||
resp = _run(handler("tenant-1"))
|
||||
assert resp.mimetype == "audio/mpeg"
|
||||
|
||||
@ -31,7 +31,7 @@ def add_memory_func(client, request):
|
||||
payload = {
|
||||
"name": f"test_memory_{i}",
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = client.create_memory(**payload)
|
||||
|
||||
@ -36,7 +36,7 @@ class TestAuthorization:
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.create_memory(**{"name": "test_memory", "memory_type": ["raw"], "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", "llm_id": "glm-4-flash@ZHIPU-AI"})
|
||||
client.create_memory(**{"name": "test_memory", "memory_type": ["raw"], "embd_id": "BAAI/bge-small-en-v1.5@Builtin", "llm_id": "glm-4-flash@ZHIPU-AI"})
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
memory = client.create_memory(**payload)
|
||||
@ -72,7 +72,7 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
@ -86,7 +86,7 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["something"],
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
@ -99,7 +99,7 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res1 = client.create_memory(**payload)
|
||||
|
||||
@ -207,6 +207,10 @@ def _load_chunk_module(monkeypatch):
|
||||
EMBEDDING = SimpleNamespace(value="embedding")
|
||||
CHAT = SimpleNamespace(value="chat")
|
||||
RERANK = SimpleNamespace(value="rerank")
|
||||
SPEECH2TEXT = SimpleNamespace(value="speech2text")
|
||||
IMAGE2TEXT = SimpleNamespace(value="image2text")
|
||||
TTS = SimpleNamespace(value="tts")
|
||||
OCR = SimpleNamespace(value="ocr")
|
||||
|
||||
constants_mod.RetCode = _DummyRetCode
|
||||
constants_mod.LLMType = _DummyLLMType
|
||||
@ -301,6 +305,10 @@ def _load_chunk_module(monkeypatch):
|
||||
def get_embd_id(_doc_id):
|
||||
return "embed-1"
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_embd_id(_doc_id):
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def decrement_chunk_num(*args):
|
||||
_DocumentService.decrement_calls.append(args)
|
||||
@ -327,13 +335,24 @@ def _load_chunk_module(monkeypatch):
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(_kb_id):
|
||||
return True, SimpleNamespace(pagerank=0.6)
|
||||
return True, SimpleNamespace(pagerank=0.6, tenant_embd_id=2, tenant_llm_id=1)
|
||||
|
||||
kb_service_mod.KnowledgebaseService = _KnowledgebaseService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
|
||||
services_pkg.knowledgebase_service = kb_service_mod
|
||||
|
||||
class _DummyLLMService:
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return [SimpleNamespace(
|
||||
llm_name="gpt-3.5-turbo",
|
||||
model_type="chat",
|
||||
max_tokens=8192,
|
||||
is_tools=True
|
||||
)]
|
||||
|
||||
llm_service_mod = ModuleType("api.db.services.llm_service")
|
||||
llm_service_mod.LLMService = _DummyLLMService
|
||||
llm_service_mod.LLMBundle = _DummyLLMBundle
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
|
||||
services_pkg.llm_service = llm_service_mod
|
||||
@ -343,6 +362,77 @@ def _load_chunk_module(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod)
|
||||
services_pkg.search_service = search_service_mod
|
||||
|
||||
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
|
||||
|
||||
class _MockTableObject:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_dict(self):
|
||||
return {k: v for k, v in self.__dict__.items()}
|
||||
|
||||
class _TenantLLMService:
|
||||
@staticmethod
|
||||
def get_by_id(tenant_model_id):
|
||||
return True, _MockTableObject(
|
||||
id=tenant_model_id,
|
||||
tenant_id="tenant-1",
|
||||
llm_factory="",
|
||||
model_type="chat",
|
||||
llm_name="gpt-3.5-turbo",
|
||||
api_key="fake-api-key",
|
||||
api_base="https://api.example.com",
|
||||
max_tokens=8192,
|
||||
used_tokens=0,
|
||||
status=1
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(tenant_id, model_name):
|
||||
return _MockTableObject(
|
||||
id=1,
|
||||
tenant_id=tenant_id,
|
||||
llm_factory="",
|
||||
model_type="chat",
|
||||
llm_name=model_name,
|
||||
api_key="fake-api-key",
|
||||
api_base="https://api.example.com",
|
||||
max_tokens=8192,
|
||||
used_tokens=0,
|
||||
status=1
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def split_model_name_and_factory(model_name):
|
||||
if "@" in model_name:
|
||||
parts = model_name.rsplit("@", 1)
|
||||
return parts[0], parts[1]
|
||||
return model_name, None
|
||||
|
||||
@staticmethod
|
||||
def increase_usage_by_id(model_id, used_tokens):
|
||||
return True
|
||||
|
||||
class _TenantService:
|
||||
@staticmethod
|
||||
def get_by_id(tenant_id):
|
||||
return True, SimpleNamespace(
|
||||
llm_id="gpt-3.5-turbo",
|
||||
tenant_llm_id=1,
|
||||
embd_id="text-embedding-ada-002",
|
||||
tenant_embd_id=2,
|
||||
asr_id="whisper-1",
|
||||
img2txt_id="gpt-4-vision-preview",
|
||||
rerank_id="bge-reranker",
|
||||
tts_id="tts-1"
|
||||
)
|
||||
|
||||
tenant_llm_service_mod.TenantLLMService = _TenantLLMService
|
||||
tenant_llm_service_mod.TenantService = _TenantService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
|
||||
services_pkg.tenant_llm_service = tenant_llm_service_mod
|
||||
|
||||
user_service_mod = ModuleType("api.db.services.user_service")
|
||||
|
||||
class _UserTenantService:
|
||||
@ -775,7 +865,7 @@ def test_retrieval_test_branch_matrix_unit(monkeypatch):
|
||||
assert "Knowledgebase not found!" in res["message"], res
|
||||
|
||||
retriever = _Retriever(mode="ok")
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-1")), raising=False)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-1", tenant_embd_id=2)), raising=False)
|
||||
monkeypatch.setattr(module.settings, "retriever", retriever)
|
||||
monkeypatch.setattr(module.settings, "kg_retriever", _KgRetriever(), raising=False)
|
||||
_set_request_json(
|
||||
|
||||
@ -143,6 +143,19 @@ def _load_conversation_module(monkeypatch):
|
||||
apps_mod.login_required = lambda func: func
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
|
||||
|
||||
# Create user_service module with TenantService stub if not already exists
|
||||
if "api.db.services.user_service" not in sys.modules:
|
||||
user_service_mod = ModuleType("api.db.services.user_service")
|
||||
user_service_mod.UserService = SimpleNamespace() # Dummy UserService class
|
||||
user_service_mod.TenantService = SimpleNamespace(
|
||||
get_info_by=lambda _uid: [],
|
||||
get_by_id=lambda _uid: (False, None)
|
||||
)
|
||||
user_service_mod.UserTenantService = SimpleNamespace(
|
||||
query=lambda **_kwargs: []
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
|
||||
|
||||
module_name = "test_conversation_routes_unit_module"
|
||||
module_path = repo_root / "api" / "apps" / "conversation_app.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
@ -519,15 +532,15 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
|
||||
|
||||
wav_file = _DummyUploadedFile("audio.wav")
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (False, None))
|
||||
res = _run(module.sequence2txt())
|
||||
assert res["message"] == "Tenant not found!"
|
||||
assert res["message"] == "Tenant not found"
|
||||
|
||||
wav_file = _DummyUploadedFile("audio.wav")
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "asr_id": ""}])
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", asr_id="")))
|
||||
res = _run(module.sequence2txt())
|
||||
assert res["message"] == "No default ASR model is set"
|
||||
assert res["message"] == "No default speech2text model is set."
|
||||
|
||||
class _SyncAsr:
|
||||
def transcription(self, _path):
|
||||
@ -538,7 +551,8 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
|
||||
|
||||
wav_file = _DummyUploadedFile("audio.wav")
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "asr_id": "asr-model"}])
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", asr_id="asr-model")))
|
||||
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": "asr-model"}))
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncAsr())
|
||||
monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("remove failed")))
|
||||
res = _run(module.sequence2txt())
|
||||
@ -579,13 +593,13 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
|
||||
def test_tts_request_parse_entry(monkeypatch):
|
||||
module = _load_conversation_module(monkeypatch)
|
||||
_set_request_json(monkeypatch, module, {"text": "A。B"})
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (False, None))
|
||||
res = _run(module.tts())
|
||||
assert res["message"] == "Tenant not found!"
|
||||
assert res["message"] == "Tenant not found"
|
||||
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "tts_id": ""}])
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", tts_id="")))
|
||||
res = _run(module.tts())
|
||||
assert res["message"] == "No default TTS model is set"
|
||||
assert res["message"] == "No default tts model is set."
|
||||
|
||||
class _TTSOk:
|
||||
def tts(self, txt):
|
||||
@ -593,7 +607,8 @@ def test_tts_request_parse_entry(monkeypatch):
|
||||
return []
|
||||
yield f"chunk-{txt}".encode("utf-8")
|
||||
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "tts_id": "tts-x"}])
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", tts_id="tts-x")))
|
||||
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": model_name}))
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
|
||||
resp = _run(module.tts())
|
||||
assert resp.mimetype == "audio/mpeg"
|
||||
@ -749,18 +764,18 @@ def test_mindmap_and_related_questions_matrix_unit(monkeypatch):
|
||||
llm_calls["options"] = options
|
||||
return "1. Alpha\n2. Beta\nignored"
|
||||
|
||||
def _fake_bundle(tenant_id, llm_type, chat_id):
|
||||
llm_calls["bundle"] = (tenant_id, llm_type, chat_id)
|
||||
def _fake_bundle(tenant_id, model_config, lang="Chinese", **kwargs):
|
||||
llm_calls["bundle"] = (tenant_id, model_config)
|
||||
return _FakeChat()
|
||||
|
||||
monkeypatch.setattr(module, "LLMBundle", _fake_bundle)
|
||||
monkeypatch.setattr(module, "load_prompt", lambda name: f"prompt-{name}")
|
||||
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": model_name}))
|
||||
_set_request_json(monkeypatch, module, {"question": "solar", "search_id": "search-1"})
|
||||
res = _run(module.related_questions.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] == ["Alpha", "Beta"]
|
||||
assert llm_calls["bundle"][0] == "user-1"
|
||||
assert llm_calls["bundle"][2] == "chat-x"
|
||||
assert llm_calls["options"] == {"temperature": 0.2}
|
||||
assert llm_calls["prompt"] == "prompt-related_question"
|
||||
assert "Keywords: solar" in llm_calls["messages"][0]["content"]
|
||||
|
||||
@ -137,11 +137,34 @@ def _load_dialog_module(monkeypatch):
|
||||
|
||||
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
|
||||
|
||||
class _MockTableObject:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_dict(self):
|
||||
return {k: v for k, v in self.__dict__.items()}
|
||||
|
||||
class _TenantLLMService:
|
||||
@staticmethod
|
||||
def split_model_name_and_factory(embd_id):
|
||||
return embd_id.split("@")
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(tenant_id, model_name):
|
||||
return _MockTableObject(
|
||||
id=1,
|
||||
tenant_id=tenant_id,
|
||||
llm_factory="",
|
||||
model_type="chat",
|
||||
llm_name=model_name,
|
||||
api_key="fake-api-key",
|
||||
api_base="https://api.example.com",
|
||||
max_tokens=8192,
|
||||
used_tokens=0,
|
||||
status=1
|
||||
)
|
||||
|
||||
tenant_llm_service_mod.TenantLLMService = _TenantLLMService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
|
||||
|
||||
@ -253,8 +276,8 @@ def test_set_dialog_branch_matrix_unit(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(module, "duplicate_name", _dup_name)
|
||||
monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(name="new dialog")])
|
||||
monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x")))
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="embd-a@builtin")])
|
||||
monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x", tenant_llm_id=1)))
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="embd-a@builtin", tenant_embd_id=2)])
|
||||
monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@"))
|
||||
monkeypatch.setattr(module.DialogService, "save", lambda **kwargs: captured.update(kwargs) or False)
|
||||
_set_request_json(
|
||||
@ -301,7 +324,7 @@ def test_set_dialog_branch_matrix_unit(monkeypatch):
|
||||
res = _run(handler())
|
||||
assert res["message"] == "Tenant not found!"
|
||||
|
||||
monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x")))
|
||||
monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x", tenant_llm_id=1)))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
@ -316,7 +339,7 @@ def test_set_dialog_branch_matrix_unit(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
module.KnowledgebaseService,
|
||||
"get_by_ids",
|
||||
lambda _ids: [SimpleNamespace(embd_id="embd-a@f1"), SimpleNamespace(embd_id="embd-b@f2")],
|
||||
lambda _ids: [SimpleNamespace(embd_id="embd-a@f1", tenant_embd_id=2), SimpleNamespace(embd_id="embd-b@f2", tenant_embd_id=2)],
|
||||
)
|
||||
monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@"))
|
||||
res = _run(handler())
|
||||
|
||||
@ -51,11 +51,19 @@ class _DummyTenantLLMModel:
|
||||
llm_factory = _ExprField("llm_factory")
|
||||
llm_name = _ExprField("llm_name")
|
||||
|
||||
def __init__(self, id=None, **kwargs):
|
||||
self.id = id
|
||||
self.api_key = None
|
||||
self.status = None
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class _TenantLLMRow:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id,
|
||||
llm_name,
|
||||
llm_factory,
|
||||
model_type,
|
||||
@ -65,6 +73,7 @@ class _TenantLLMRow:
|
||||
api_base="",
|
||||
max_tokens=8192,
|
||||
):
|
||||
self.id = id
|
||||
self.llm_name = llm_name
|
||||
self.llm_factory = llm_factory
|
||||
self.model_type = model_type
|
||||
@ -76,6 +85,7 @@ class _TenantLLMRow:
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"model_type": self.model_type,
|
||||
@ -246,8 +256,8 @@ def test_list_app_grouping_availability_and_merge(monkeypatch):
|
||||
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id))
|
||||
|
||||
tenant_rows = [
|
||||
_TenantLLMRow(llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
|
||||
_TenantLLMRow(llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
|
||||
_TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
|
||||
_TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
|
||||
]
|
||||
monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: tenant_rows)
|
||||
|
||||
@ -263,7 +273,7 @@ def test_list_app_grouping_availability_and_merge(monkeypatch):
|
||||
monkeypatch.setenv("TEI_MODEL", "tei-embed")
|
||||
|
||||
res = _run(module.list_app())
|
||||
assert res["code"] == 0
|
||||
assert res["code"] == 0, res["message"]
|
||||
assert ensure_calls == ["tenant-1"]
|
||||
|
||||
data = res["data"]
|
||||
@ -291,8 +301,8 @@ def test_list_app_model_type_filter(monkeypatch):
|
||||
module.TenantLLMService,
|
||||
"query",
|
||||
lambda **_kwargs: [
|
||||
_TenantLLMRow(llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
|
||||
_TenantLLMRow(llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
|
||||
_TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
|
||||
_TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
@ -306,7 +316,7 @@ def test_list_app_model_type_filter(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args={"model_type": "chat"}))
|
||||
res = _run(module.list_app())
|
||||
assert res["code"] == 0
|
||||
assert res["code"] == 0, res["message"]
|
||||
assert list(res["data"].keys()) == ["CustomFactory"]
|
||||
assert res["data"]["CustomFactory"][0]["model_type"] == "chat"
|
||||
|
||||
@ -799,7 +809,7 @@ def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch):
|
||||
monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, _payload: False)
|
||||
monkeypatch.setattr(module.TenantLLMService, "save", lambda **kwargs: saved.append(kwargs) or True)
|
||||
res = _call({"llm_factory": "FChatPass", "llm_name": "m", "model_type": module.LLMType.CHAT.value, "api_key": "k"})
|
||||
assert res["code"] == 0
|
||||
assert res["code"] == 0, res["message"]
|
||||
assert res["data"] is True
|
||||
assert saved
|
||||
assert saved[0]["llm_factory"] == "FChatPass"
|
||||
@ -841,6 +851,7 @@ def test_my_llms_include_details_and_exception_unit(monkeypatch):
|
||||
"query",
|
||||
lambda **_kwargs: [
|
||||
_TenantLLMRow(
|
||||
id=1,
|
||||
llm_name="chat-model",
|
||||
llm_factory="FactoryX",
|
||||
model_type="chat",
|
||||
|
||||
@ -32,7 +32,7 @@ def add_memory_func(request, WebApiAuth):
|
||||
payload = {
|
||||
"name": f"test_memory_{i}",
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = create_memory(WebApiAuth, payload)
|
||||
|
||||
@ -45,7 +45,7 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = create_memory(WebApiAuth, payload)
|
||||
@ -68,7 +68,7 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = create_memory(WebApiAuth, payload)
|
||||
@ -80,7 +80,7 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["something"],
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = create_memory(WebApiAuth, payload)
|
||||
@ -92,7 +92,7 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res1 = create_memory(WebApiAuth, payload)
|
||||
|
||||
@ -216,11 +216,34 @@ def _load_user_app(monkeypatch):
|
||||
|
||||
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
|
||||
|
||||
class _MockTableObject:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_dict(self):
|
||||
return {k: v for k, v in self.__dict__.items()}
|
||||
|
||||
class _StubTenantLLMService:
|
||||
@staticmethod
|
||||
def insert_many(_payload):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(tenant_id, model_name):
|
||||
return _MockTableObject(
|
||||
id=1,
|
||||
tenant_id=tenant_id,
|
||||
llm_factory="",
|
||||
model_type="chat",
|
||||
llm_name=model_name,
|
||||
api_key="fake-api-key",
|
||||
api_base="https://api.example.com",
|
||||
max_tokens=8192,
|
||||
used_tokens=0,
|
||||
status=1
|
||||
)
|
||||
|
||||
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user