Feat/tenant model (#13072)

### What problem does this PR solve?

Add an `id` column to the `tenant_llm` table and use it in `LLMBundle`.

### Type of change

- [x] Refactoring

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
Lynn
2026-03-05 17:27:17 +08:00
committed by GitHub
parent 47540a4147
commit 62cb292635
54 changed files with 1754 additions and 361 deletions

View File

@ -44,9 +44,10 @@ class _AwaitableValue:
class _DummyKB:
def __init__(self, embd_id="embd@factory", chunk_num=1):
def __init__(self, embd_id="embd@factory", chunk_num=1, tenant_embd_id=1):
self.embd_id = embd_id
self.chunk_num = chunk_num
self.tenant_embd_id = tenant_embd_id
def to_json(self):
return {"id": "kb-1"}

View File

@ -45,9 +45,10 @@ class _AwaitableValue:
class _DummyKB:
def __init__(self, tenant_id="tenant-1", embd_id="embd-1"):
def __init__(self, tenant_id="tenant-1", embd_id="embd-1", tenant_embd_id=1):
self.tenant_id = tenant_id
self.embd_id = embd_id
self.tenant_embd_id = tenant_embd_id
class _DummyRetriever:
@ -102,6 +103,138 @@ def _load_dify_retrieval_module(monkeypatch):
monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))
# Mock tenant_llm_service for TenantLLMService and TenantService
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
class _MockModelConfig:
def __init__(self, tenant_id, model_name):
self.tenant_id = tenant_id
self.llm_name = model_name
self.llm_factory = "Builtin"
self.api_key = "fake-api-key"
self.api_base = "https://api.example.com"
self.model_type = "chat"
self.max_tokens = 8192
self.used_tokens = 0
self.status = 1
self.id = 1
def to_dict(self):
return {
"tenant_id": self.tenant_id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"api_key": self.api_key,
"api_base": self.api_base,
"model_type": self.model_type,
"max_tokens": self.max_tokens,
"used_tokens": self.used_tokens,
"status": self.status,
"id": self.id
}
class _StubTenantService:
@staticmethod
def get_by_id(tenant_id):
# Return a mock tenant with default model configurations
return True, SimpleNamespace(
id=tenant_id,
llm_id="chat-model",
embd_id="embd-model",
asr_id="asr-model",
img2txt_id="img2txt-model",
rerank_id="rerank-model",
tts_id="tts-model"
)
class _StubTenantLLMService:
@staticmethod
def get_api_key(tenant_id, model_name):
return _MockModelConfig(tenant_id, model_name)
@staticmethod
def split_model_name_and_factory(model_name):
if "@" in model_name:
parts = model_name.split("@")
return parts[0], parts[1]
return model_name, None
tenant_llm_service_mod.TenantService = _StubTenantService
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
# Mock llm_service for LLMService
llm_service_mod = ModuleType("api.db.services.llm_service")
class _StubLLM:
def __init__(self, llm_name):
self.llm_name = llm_name
self.is_tools = False
class _StubLLMBundle:
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.model_config = model_config
self.lang = lang
def encode(self, texts: list):
import numpy as np
# Return mock embeddings and token usage
return [np.array([0.1, 0.2, 0.3]) for _ in texts], len(texts) * 10
llm_service_mod.LLMService = SimpleNamespace(
query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []
)
llm_service_mod.LLMBundle = _StubLLMBundle
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
# Mock tenant_model_service to ensure it uses mocked services
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
class _MockModelConfig2:
def __init__(self, tenant_id, model_name):
self.tenant_id = tenant_id
self.llm_name = model_name
self.llm_factory = "Builtin"
self.api_key = "fake-api-key"
self.api_base = "https://api.example.com"
self.model_type = "chat"
self.max_tokens = 8192
self.used_tokens = 0
self.status = 1
self.id = 1
def to_dict(self):
return {
"tenant_id": self.tenant_id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"api_key": self.api_key,
"api_base": self.api_base,
"model_type": self.model_type,
"max_tokens": self.max_tokens,
"used_tokens": self.used_tokens,
"status": self.status,
"id": self.id
}
def _get_model_config_by_id(tenant_model_id: int) -> dict:
return _MockModelConfig2("tenant-1", "model-1").to_dict()
def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
if not model_name:
raise Exception("Model Name is required")
return _MockModelConfig2(tenant_id, model_name).to_dict()
def _get_tenant_default_model_by_type(tenant_id: str, model_type):
# Return mock tenant with default model configurations
return _MockModelConfig2(tenant_id, "chat-model").to_dict()
tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id
tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name
tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
module_name = "test_dify_retrieval_routes_unit_module"
module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
@ -134,7 +267,6 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch):
monkeypatch.setattr(module, "jsonify", lambda payload: payload)
monkeypatch.setattr(module.DocMetadataService, "get_meta_by_kbs", lambda _kb_ids: [{"doc_id": "doc-1"}])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: object())
monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", []))
monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: [])
@ -185,7 +317,6 @@ def test_retrieval_not_found_exception_mapping(monkeypatch):
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
monkeypatch.setattr(module.DocMetadataService, "get_meta_by_kbs", lambda _kb_ids: [])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: object())
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
class _BrokenRetriever:
@ -205,7 +336,6 @@ def test_retrieval_generic_exception_mapping(monkeypatch):
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"})
monkeypatch.setattr(module.DocMetadataService, "get_meta_by_kbs", lambda _kb_ids: [])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB()))
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: object())
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
class _BrokenRetriever:

View File

@ -151,6 +151,149 @@ def _load_doc_module(monkeypatch):
monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))
# Mock tenant_llm_service for TenantLLMService and TenantService
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
class _MockModelConfig:
def __init__(self, tenant_id, model_name):
self.tenant_id = tenant_id
self.llm_name = model_name
self.llm_factory = "Builtin"
self.api_key = "fake-api-key"
self.api_base = "https://api.example.com"
self.model_type = "embedding"
self.max_tokens = 8192
self.used_tokens = 0
self.status = 1
self.id = 1
def to_dict(self):
return {
"tenant_id": self.tenant_id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"api_key": self.api_key,
"api_base": self.api_base,
"model_type": self.model_type,
"max_tokens": self.max_tokens,
"used_tokens": self.used_tokens,
"status": self.status,
"id": self.id
}
class _StubTenantService:
@staticmethod
def get_by_id(tenant_id):
return True, SimpleNamespace(
id=tenant_id,
llm_id="chat-model",
embd_id="embd-model",
asr_id="asr-model",
img2txt_id="img2txt-model",
rerank_id="rerank-model",
tts_id="tts-model"
)
class _StubTenantLLMService:
@staticmethod
def get_api_key(tenant_id, model_name):
return _MockModelConfig(tenant_id, model_name)
@staticmethod
def split_model_name_and_factory(model_name):
if "@" in model_name:
parts = model_name.split("@")
return parts[0], parts[1]
return model_name, None
@staticmethod
def get_by_id(tenant_model_id):
return True, _MockModelConfig("tenant-1", "model-1")
@staticmethod
def model_instance(model_config):
class _EmbedModel:
def encode(self, texts):
import numpy as np
return [np.array([0.2, 0.8]), np.array([0.3, 0.7])], 1
return _EmbedModel()
tenant_llm_service_mod.TenantService = _StubTenantService
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
# Mock LLMService
llm_service_mod = ModuleType("api.db.services.llm_service")
class _StubLLM:
def __init__(self, llm_name):
self.llm_name = llm_name
self.is_tools = False
class _StubLLMBundle:
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.model_config = model_config
self.lang = lang
def encode(self, texts: list):
import numpy as np
# Return mock embeddings and token usage
return [np.array([0.2, 0.8]), np.array([0.3, 0.7])], len(texts) * 10
llm_service_mod.LLMService = SimpleNamespace(
query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []
)
llm_service_mod.LLMBundle = _StubLLMBundle
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
# Mock tenant_model_service to ensure it uses mocked services
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
class _MockModelConfig2:
def __init__(self, tenant_id, model_name):
self.tenant_id = tenant_id
self.llm_name = model_name
self.llm_factory = "Builtin"
self.api_key = "fake-api-key"
self.api_base = "https://api.example.com"
self.model_type = "embedding"
self.max_tokens = 8192
self.used_tokens = 0
self.status = 1
self.id = 1
def to_dict(self):
return {
"tenant_id": self.tenant_id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"api_key": self.api_key,
"api_base": self.api_base,
"model_type": self.model_type,
"max_tokens": self.max_tokens,
"used_tokens": self.used_tokens,
"status": self.status,
"id": self.id
}
def _get_model_config_by_id(tenant_model_id: int) -> dict:
return _MockModelConfig2("tenant-1", "model-1").to_dict()
def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
if not model_name:
raise Exception("Model Name is required")
return _MockModelConfig2(tenant_id, model_name).to_dict()
def _get_tenant_default_model_by_type(tenant_id: str, model_type):
# Return mock tenant with default model configurations
return _MockModelConfig2(tenant_id, "chat-model").to_dict()
tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id
tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name
tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
module_path = repo_root / "api" / "apps" / "sdk" / "doc.py"
spec = importlib.util.spec_from_file_location("test_doc_sdk_routes_unit", module_path)
module = importlib.util.module_from_spec(spec)
@ -762,6 +905,7 @@ class TestDocRoutesUnit:
monkeypatch.setattr(module.rag_tokenizer, "fine_grained_tokenize", lambda text: text or "")
monkeypatch.setattr(module.rag_tokenizer, "is_chinese", lambda _text: False)
monkeypatch.setattr(module.DocumentService, "get_embd_id", lambda _doc_id: "embd")
monkeypatch.setattr(module.DocumentService, "get_tenant_embd_id", lambda _doc_id: 1)
class _EmbedModel:
def encode(self, _texts):
@ -851,8 +995,8 @@ class TestDocRoutesUnit:
"get_request_json",
lambda: _AwaitableValue({"dataset_ids": ["ds-1"], "question": "q", "highlight": "True"}),
)
monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="m1", tenant_id="tenant-1")])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, SimpleNamespace(tenant_id="tenant-1", embd_id="m1")))
monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="m1", tenant_id="tenant-1", tenant_embd_id=1)])
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, SimpleNamespace(tenant_id="tenant-1", embd_id="m1", tenant_embd_id=1)))
class _Retriever:
async def retrieval(self, *_args, **_kwargs):
@ -865,7 +1009,7 @@ class TestDocRoutesUnit:
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: {})
monkeypatch.setattr(module.settings, "retriever", _Retriever())
res = _run(module.retrieval_test.__wrapped__("tenant-1"))
assert res["code"] == 0
assert res["code"] == 0, res["message"]
assert res["data"]["chunks"] == []
monkeypatch.setattr(
@ -974,7 +1118,7 @@ class TestDocRoutesUnit:
}
),
)
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, SimpleNamespace(tenant_id="tenant-1", embd_id="m1")))
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, SimpleNamespace(tenant_id="tenant-1", embd_id="m1", tenant_embd_id=1)))
monkeypatch.setattr(module, "cross_languages", _cross_languages)
monkeypatch.setattr(module, "keyword_extraction", _keyword_extraction)
monkeypatch.setattr(module.settings, "retriever", _FeatureRetriever())
@ -982,7 +1126,7 @@ class TestDocRoutesUnit:
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: {})
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: SimpleNamespace())
res = _run(module.retrieval_test.__wrapped__("tenant-1"))
assert res["code"] == 0
assert res["code"] == 0, res["message"]
assert feature_calls["cross"] == ("fr",)
assert feature_calls["keyword"] == "q-xl"
assert feature_calls["retrieval_question"] == "q-xl-kw"

View File

@ -121,6 +121,132 @@ def _load_session_module(monkeypatch):
common_pkg.__path__ = [str(repo_root / "common")]
monkeypatch.setitem(sys.modules, "common", common_pkg)
# Mock common.constants module
from enum import Enum
from strenum import StrEnum
class _StubLLMType(StrEnum):
CHAT = "chat"
EMBEDDING = "embedding"
SPEECH2TEXT = "speech2text"
IMAGE2TEXT = "image2text"
RERANK = "rerank"
TTS = "tts"
OCR = "ocr"
class _StubParserType(StrEnum):
PRESENTATION = "presentation"
LAWS = "laws"
MANUAL = "manual"
PAPER = "paper"
RESUME = "resume"
BOOK = "book"
QA = "qa"
TABLE = "table"
NAIVE = "naive"
PICTURE = "picture"
ONE = "one"
AUDIO = "audio"
EMAIL = "email"
KG = "knowledge_graph"
TAG = "tag"
class _StubRetCode(int, Enum):
SUCCESS = 0
NOT_EFFECTIVE = 10
EXCEPTION_ERROR = 100
ARGUMENT_ERROR = 101
DATA_ERROR = 102
OPERATING_ERROR = 103
CONNECTION_ERROR = 105
RUNNING = 106
PERMISSION_ERROR = 108
AUTHENTICATION_ERROR = 109
BAD_REQUEST = 400
UNAUTHORIZED = 401
SERVER_ERROR = 500
FORBIDDEN = 403
NOT_FOUND = 404
CONFLICT = 409
class _StubStatusEnum(str, Enum):
VALID = "1"
INVALID = "0"
class _StubActiveEnum(Enum):
ACTIVE = "1"
INACTIVE = "0"
class _StubStorage(Enum):
MINIO = 1
AZURE_SPN = 2
AZURE_SAS = 3
AWS_S3 = 4
OSS = 5
OPENDAL = 6
GCS = 7
class _StubMCPServerType(StrEnum):
SSE = "sse"
STREAMABLE_HTTP = "streamable-http"
class _StubTaskStatus(StrEnum):
UNSTART = "0"
RUNNING = "1"
CANCEL = "2"
DONE = "3"
FAIL = "4"
SCHEDULE = "5"
class _StubFileSource(StrEnum):
LOCAL = ""
KNOWLEDGEBASE = "knowledgebase"
S3 = "s3"
NOTION = "notion"
DISCORD = "discord"
CONFLUENCE = "confluence"
GMAIL = "gmail"
GOOGLE_DRIVE = "google_drive"
JIRA = "jira"
SHAREPOINT = "sharepoint"
SLACK = "slack"
TEAMS = "teams"
WEBDAV = "webdav"
MOODLE = "moodle"
DROPBOX = "dropbox"
BOX = "box"
R2 = "r2"
OCI_STORAGE = "oci_storage"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
AIRTABLE = "airtable"
ASANA = "asana"
GITHUB = "github"
GITLAB = "gitlab"
IMAP = "imap"
BITBUCKET = "bitbucket"
ZENDESK = "zendesk"
SEAFILE = "seafile"
MYSQL = "mysql"
POSTGRESQL = "postgresql"
common_constants_mod = ModuleType("common.constants")
common_constants_mod.LLMType = _StubLLMType
common_constants_mod.ParserType = _StubParserType
common_constants_mod.RetCode = _StubRetCode
common_constants_mod.StatusEnum = _StubStatusEnum
common_constants_mod.ActiveEnum = _StubActiveEnum
common_constants_mod.Storage = _StubStorage
common_constants_mod.MCPServerType = _StubMCPServerType
common_constants_mod.TaskStatus = _StubTaskStatus
common_constants_mod.FileSource = _StubFileSource
common_constants_mod.SERVICE_CONF = "service_conf.yaml"
common_constants_mod.RAG_FLOW_SERVICE_NAME = "ragflow"
common_constants_mod.SVR_QUEUE_NAME = "rag_flow_svr_queue"
common_constants_mod.SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
common_constants_mod.PAGERANK_FLD = "pagerank_fea"
common_constants_mod.TAG_FLD = "tag_feas"
monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod)
deepdoc_pkg = ModuleType("deepdoc")
deepdoc_parser_pkg = ModuleType("deepdoc.parser")
deepdoc_parser_pkg.__path__ = []
@ -166,6 +292,180 @@ def _load_session_module(monkeypatch):
monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))
# Mock tenant_llm_service for TenantLLMService and TenantService
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
class _MockModelConfig:
def __init__(self, tenant_id, model_name):
self.tenant_id = tenant_id
self.llm_name = model_name
self.llm_factory = "Builtin"
self.api_key = "fake-api-key"
self.api_base = "https://api.example.com"
self.model_type = "chat"
self.max_tokens = 8192
self.used_tokens = 0
self.status = 1
self.id = 1
def to_dict(self):
return {
"tenant_id": self.tenant_id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"api_key": self.api_key,
"api_base": self.api_base,
"model_type": self.model_type,
"max_tokens": self.max_tokens,
"used_tokens": self.used_tokens,
"status": self.status,
"id": self.id
}
class _StubTenantService:
@staticmethod
def get_by_id(tenant_id):
# Return a mock tenant with default model configurations
return True, SimpleNamespace(
id=tenant_id,
llm_id="chat-model",
embd_id="embd-model",
asr_id="asr-model",
img2txt_id="img2txt-model",
rerank_id="rerank-model",
tts_id="tts-model"
)
class _StubTenantLLMService:
@staticmethod
def get_api_key(tenant_id, model_name):
return _MockModelConfig(tenant_id, model_name)
@staticmethod
def split_model_name_and_factory(model_name):
if "@" in model_name:
parts = model_name.split("@")
return parts[0], parts[1]
return model_name, None
class _StubLLMFactoriesService:
@staticmethod
def query(**_kwargs):
return []
tenant_llm_service_mod.TenantService = _StubTenantService
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
tenant_llm_service_mod.LLMFactoriesService = _StubLLMFactoriesService
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
# Mock LLMService
llm_service_mod = ModuleType("api.db.services.llm_service")
class _StubLLM:
def __init__(self, llm_name):
self.llm_name = llm_name
self.is_tools = False
llm_service_mod.LLMService = SimpleNamespace(
query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []
)
class _StubLLMBundle:
def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.model_config = model_config
self.lang = lang
async def async_chat(self, prompt, messages, options):
return "mock response"
def transcription(self, audio_path):
return "mock transcription"
llm_service_mod.LLMBundle = _StubLLMBundle
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
# Mock tenant_model_service to ensure it uses mocked services
tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service")
class _MockModelConfig2:
def __init__(self, tenant_id, model_name, model_type="chat"):
self.tenant_id = tenant_id
self.llm_name = model_name
self.llm_factory = "Builtin"
self.api_key = "fake-api-key"
self.api_base = "https://api.example.com"
self.model_type = model_type
self.max_tokens = 8192
self.used_tokens = 0
self.status = 1
self.id = 1
def to_dict(self):
return {
"tenant_id": self.tenant_id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"api_key": self.api_key,
"api_base": self.api_base,
"model_type": self.model_type,
"max_tokens": self.max_tokens,
"used_tokens": self.used_tokens,
"status": self.status,
"id": self.id
}
def _get_model_config_by_id(tenant_model_id: int) -> dict:
return _MockModelConfig2("tenant-1", "model-1").to_dict()
def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
if not model_name:
raise Exception("Model Name is required")
return _MockModelConfig2(tenant_id, model_name, model_type).to_dict()
def _get_tenant_default_model_by_type(tenant_id: str, model_type):
# Check if tenant exists
from api.db.services.tenant_llm_service import TenantService
exist, tenant = TenantService.get_by_id(tenant_id)
if not exist:
raise LookupError("Tenant not found!")
# Return mock tenant with default model configurations
model_type_val = model_type if isinstance(model_type, str) else model_type.value
model_name = ""
if model_type_val == "embedding":
model_name = tenant.embd_id
elif model_type_val == "speech2text":
model_name = tenant.asr_id
elif model_type_val == "image2text":
model_name = tenant.img2txt_id
elif model_type_val == "chat":
model_name = tenant.llm_id
elif model_type_val == "rerank":
model_name = tenant.rerank_id
elif model_type_val == "tts":
model_name = tenant.tts_id
elif model_type_val == "ocr":
raise Exception("OCR model name is required")
if not model_name:
# Use friendly model type names
friendly_names = {
"embedding": "Embedding",
"speech2text": "ASR",
"image2text": "Image2Text",
"chat": "Chat",
"rerank": "Rerank",
"tts": "TTS",
"ocr": "OCR"
}
friendly_name = friendly_names.get(model_type_val, model_type_val)
raise Exception(f"No default {friendly_name} model is set")
return _MockModelConfig2(tenant_id, model_name, model_type_val).to_dict()
tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id
tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name
tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type
monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod)
agent_pkg = ModuleType("agent")
agent_pkg.__path__ = []
agent_canvas_mod = ModuleType("agent.canvas")
@ -200,6 +500,29 @@ def _load_session_module(monkeypatch):
module.manager = _DummyManager()
monkeypatch.setitem(sys.modules, "test_session_sdk_routes_unit_module", module)
spec.loader.exec_module(module)
# Add TenantService to module for test compatibility
class _StubTenantServiceForTest:
@staticmethod
def get_info_by(tenant_id):
# Return mock tenant info for tests
return []
@staticmethod
def get_by_id(tenant_id):
# Return mock tenant by id
return True, SimpleNamespace(
id=tenant_id,
llm_id="chat-model",
embd_id="embd-model",
asr_id="asr-model",
img2txt_id="img2txt-model",
rerank_id="rerank-model",
tts_id="tts-model"
)
module.TenantService = _StubTenantServiceForTest
return module
@ -1028,9 +1351,12 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
llm_calls = []
def _fake_llm_bundle(tenant_id, llm_type, *args, **kwargs):
llm_calls.append((tenant_id, llm_type, args, kwargs))
return SimpleNamespace(tenant_id=tenant_id, llm_type=llm_type, args=args, kwargs=kwargs)
def _fake_llm_bundle(tenant_id, model_config, *args, **kwargs):
# Extract llm_type from model_config for comparison
llm_type = model_config.get("model_type") if isinstance(model_config, dict) else model_config
llm_name = model_config.get("llm_name") if isinstance(model_config, dict) else None
llm_calls.append((tenant_id, llm_type, llm_name, args, kwargs))
return SimpleNamespace(tenant_id=tenant_id, llm_type=llm_type, llm_name=llm_name, args=args, kwargs=kwargs)
monkeypatch.setattr(module, "LLMBundle", _fake_llm_bundle)
monkeypatch.setattr(
@ -1128,7 +1454,7 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
monkeypatch.setattr(
module.KnowledgebaseService,
"get_by_id",
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model")),
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model", tenant_embd_id=None)),
)
res = _run(handler())
assert res["code"] == 0
@ -1143,7 +1469,7 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
assert retrieval_capture["local_doc_ids"] == ["doc-filtered"]
assert retrieval_capture["rank_feature"] == ["label-1"]
assert retrieval_capture["rerank_mdl"] is not None
assert any(call[1] == module.LLMType.EMBEDDING.value and call[3].get("llm_name") == "embd-model" for call in llm_calls)
assert any(call[1] == module.LLMType.EMBEDDING.value and call[2] == "embd-model" for call in llm_calls)
llm_calls.clear()
@ -1178,7 +1504,7 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
monkeypatch.setattr(
module.KnowledgebaseService,
"get_by_id",
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model")),
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model", tenant_embd_id=None)),
)
res = _run(handler())
assert res["code"] == 0
@ -1247,7 +1573,11 @@ def test_searchbots_related_questions_embedded_matrix_unit(monkeypatch):
res = _run(handler())
assert res["code"] == 0
assert res["data"] == ["Alpha", "Beta"]
assert captured["bundle_args"] == ("tenant-1", module.LLMType.CHAT, "chat-x")
# LLMBundle is called with (tenant_id, model_config)
# model_config is a dict with model_type, llm_name, etc.
assert captured["bundle_args"][0] == "tenant-1"
assert captured["bundle_args"][1].get("model_type") == module.LLMType.CHAT
assert captured["bundle_args"][1].get("llm_name") == "chat-x"
assert captured["options"] == {"temperature": 0.2}
assert "Keywords: solar" in captured["messages"][0]["content"]
@ -1361,12 +1691,14 @@ def test_sequence2txt_embedded_validation_and_stream_matrix_unit(monkeypatch):
assert "Unsupported audio format: .txt" in res["message"]
_set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [])
tenant_llm_service = sys.modules["api.db.services.tenant_llm_service"]
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (False, None))
res = _run(handler("tenant-1"))
assert res["message"] == "Tenant not found!"
_set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "asr_id": ""}])
tenant_llm_service = sys.modules["api.db.services.tenant_llm_service"]
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="", tts_id="", llm_id="", embd_id="", img2txt_id="", rerank_id="")))
res = _run(handler("tenant-1"))
assert res["message"] == "No default ASR model is set"
@ -1378,7 +1710,7 @@ def test_sequence2txt_embedded_validation_and_stream_matrix_unit(monkeypatch):
return []
_set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "asr_id": "asr-x"}])
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="asr-x", tts_id="", llm_id="", embd_id="", img2txt_id="", rerank_id="")))
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncASR())
monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup fail")))
res = _run(handler("tenant-1"))
@ -1420,14 +1752,15 @@ def test_sequence2txt_embedded_validation_and_stream_matrix_unit(monkeypatch):
def test_tts_embedded_stream_and_error_matrix_unit(monkeypatch):
module = _load_session_module(monkeypatch)
handler = inspect.unwrap(module.tts)
monkeypatch.setattr(module, "Response", _StubResponse)
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"text": "A。B"}))
monkeypatch.setattr(module, "Response", _StubResponse)
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [])
tenant_llm_service = sys.modules["api.db.services.tenant_llm_service"]
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (False, None))
res = _run(handler("tenant-1"))
assert res["message"] == "Tenant not found!"
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "tts_id": ""}])
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="", tts_id="", llm_id="", embd_id="", img2txt_id="", rerank_id="")))
res = _run(handler("tenant-1"))
assert res["message"] == "No default TTS model is set"
@ -1437,7 +1770,7 @@ def test_tts_embedded_stream_and_error_matrix_unit(monkeypatch):
return []
yield f"chunk-{txt}".encode("utf-8")
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "tts_id": "tts-x"}])
monkeypatch.setattr(tenant_llm_service.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(asr_id="", tts_id="tts-x", llm_id="", embd_id="", img2txt_id="", rerank_id="")))
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
resp = _run(handler("tenant-1"))
assert resp.mimetype == "audio/mpeg"

View File

@ -31,7 +31,7 @@ def add_memory_func(client, request):
payload = {
"name": f"test_memory_{i}",
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = client.create_memory(**payload)

View File

@ -36,7 +36,7 @@ class TestAuthorization:
def test_auth_invalid(self, invalid_auth, expected_message):
client = RAGFlow(invalid_auth, HOST_ADDRESS)
with pytest.raises(Exception) as exception_info:
client.create_memory(**{"name": "test_memory", "memory_type": ["raw"], "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", "llm_id": "glm-4-flash@ZHIPU-AI"})
client.create_memory(**{"name": "test_memory", "memory_type": ["raw"], "embd_id": "BAAI/bge-small-en-v1.5@Builtin", "llm_id": "glm-4-flash@ZHIPU-AI"})
assert str(exception_info.value) == expected_message, str(exception_info.value)
@ -50,7 +50,7 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
memory = client.create_memory(**payload)
@ -72,7 +72,7 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
with pytest.raises(Exception) as exception_info:
@ -86,7 +86,7 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["something"],
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
with pytest.raises(Exception) as exception_info:
@ -99,7 +99,7 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res1 = client.create_memory(**payload)

View File

@ -207,6 +207,10 @@ def _load_chunk_module(monkeypatch):
EMBEDDING = SimpleNamespace(value="embedding")
CHAT = SimpleNamespace(value="chat")
RERANK = SimpleNamespace(value="rerank")
SPEECH2TEXT = SimpleNamespace(value="speech2text")
IMAGE2TEXT = SimpleNamespace(value="image2text")
TTS = SimpleNamespace(value="tts")
OCR = SimpleNamespace(value="ocr")
constants_mod.RetCode = _DummyRetCode
constants_mod.LLMType = _DummyLLMType
@ -301,6 +305,10 @@ def _load_chunk_module(monkeypatch):
def get_embd_id(_doc_id):
return "embed-1"
@staticmethod
def get_tenant_embd_id(_doc_id):
return 1
@staticmethod
def decrement_chunk_num(*args):
_DocumentService.decrement_calls.append(args)
@ -327,13 +335,24 @@ def _load_chunk_module(monkeypatch):
@staticmethod
def get_by_id(_kb_id):
return True, SimpleNamespace(pagerank=0.6)
return True, SimpleNamespace(pagerank=0.6, tenant_embd_id=2, tenant_llm_id=1)
kb_service_mod.KnowledgebaseService = _KnowledgebaseService
monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
services_pkg.knowledgebase_service = kb_service_mod
class _DummyLLMService:
@staticmethod
def query(**_kwargs):
return [SimpleNamespace(
llm_name="gpt-3.5-turbo",
model_type="chat",
max_tokens=8192,
is_tools=True
)]
llm_service_mod = ModuleType("api.db.services.llm_service")
llm_service_mod.LLMService = _DummyLLMService
llm_service_mod.LLMBundle = _DummyLLMBundle
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
services_pkg.llm_service = llm_service_mod
@ -343,6 +362,77 @@ def _load_chunk_module(monkeypatch):
monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod)
services_pkg.search_service = search_service_mod
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
class _MockTableObject:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
def to_dict(self):
return {k: v for k, v in self.__dict__.items()}
class _TenantLLMService:
@staticmethod
def get_by_id(tenant_model_id):
return True, _MockTableObject(
id=tenant_model_id,
tenant_id="tenant-1",
llm_factory="",
model_type="chat",
llm_name="gpt-3.5-turbo",
api_key="fake-api-key",
api_base="https://api.example.com",
max_tokens=8192,
used_tokens=0,
status=1
)
@staticmethod
def get_api_key(tenant_id, model_name):
return _MockTableObject(
id=1,
tenant_id=tenant_id,
llm_factory="",
model_type="chat",
llm_name=model_name,
api_key="fake-api-key",
api_base="https://api.example.com",
max_tokens=8192,
used_tokens=0,
status=1
)
@staticmethod
def split_model_name_and_factory(model_name):
if "@" in model_name:
parts = model_name.rsplit("@", 1)
return parts[0], parts[1]
return model_name, None
@staticmethod
def increase_usage_by_id(model_id, used_tokens):
return True
class _TenantService:
@staticmethod
def get_by_id(tenant_id):
return True, SimpleNamespace(
llm_id="gpt-3.5-turbo",
tenant_llm_id=1,
embd_id="text-embedding-ada-002",
tenant_embd_id=2,
asr_id="whisper-1",
img2txt_id="gpt-4-vision-preview",
rerank_id="bge-reranker",
tts_id="tts-1"
)
tenant_llm_service_mod.TenantLLMService = _TenantLLMService
tenant_llm_service_mod.TenantService = _TenantService
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
services_pkg.tenant_llm_service = tenant_llm_service_mod
user_service_mod = ModuleType("api.db.services.user_service")
class _UserTenantService:
@ -775,7 +865,7 @@ def test_retrieval_test_branch_matrix_unit(monkeypatch):
assert "Knowledgebase not found!" in res["message"], res
retriever = _Retriever(mode="ok")
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-1")), raising=False)
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-1", tenant_embd_id=2)), raising=False)
monkeypatch.setattr(module.settings, "retriever", retriever)
monkeypatch.setattr(module.settings, "kg_retriever", _KgRetriever(), raising=False)
_set_request_json(

View File

@ -143,6 +143,19 @@ def _load_conversation_module(monkeypatch):
apps_mod.login_required = lambda func: func
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
# Create user_service module with TenantService stub if not already exists
if "api.db.services.user_service" not in sys.modules:
user_service_mod = ModuleType("api.db.services.user_service")
user_service_mod.UserService = SimpleNamespace() # Dummy UserService class
user_service_mod.TenantService = SimpleNamespace(
get_info_by=lambda _uid: [],
get_by_id=lambda _uid: (False, None)
)
user_service_mod.UserTenantService = SimpleNamespace(
query=lambda **_kwargs: []
)
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
module_name = "test_conversation_routes_unit_module"
module_path = repo_root / "api" / "apps" / "conversation_app.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
@ -519,15 +532,15 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
wav_file = _DummyUploadedFile("audio.wav")
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (False, None))
res = _run(module.sequence2txt())
assert res["message"] == "Tenant not found!"
assert res["message"] == "Tenant not found"
wav_file = _DummyUploadedFile("audio.wav")
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "asr_id": ""}])
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", asr_id="")))
res = _run(module.sequence2txt())
assert res["message"] == "No default ASR model is set"
assert res["message"] == "No default speech2text model is set."
class _SyncAsr:
def transcription(self, _path):
@ -538,7 +551,8 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
wav_file = _DummyUploadedFile("audio.wav")
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "asr_id": "asr-model"}])
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", asr_id="asr-model")))
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": "asr-model"}))
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncAsr())
monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("remove failed")))
res = _run(module.sequence2txt())
@ -579,13 +593,13 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
def test_tts_request_parse_entry(monkeypatch):
module = _load_conversation_module(monkeypatch)
_set_request_json(monkeypatch, module, {"text": "A。B"})
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (False, None))
res = _run(module.tts())
assert res["message"] == "Tenant not found!"
assert res["message"] == "Tenant not found"
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "tts_id": ""}])
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", tts_id="")))
res = _run(module.tts())
assert res["message"] == "No default TTS model is set"
assert res["message"] == "No default tts model is set."
class _TTSOk:
def tts(self, txt):
@ -593,7 +607,8 @@ def test_tts_request_parse_entry(monkeypatch):
return []
yield f"chunk-{txt}".encode("utf-8")
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "tts_id": "tts-x"}])
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", tts_id="tts-x")))
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": model_name}))
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
resp = _run(module.tts())
assert resp.mimetype == "audio/mpeg"
@ -749,18 +764,18 @@ def test_mindmap_and_related_questions_matrix_unit(monkeypatch):
llm_calls["options"] = options
return "1. Alpha\n2. Beta\nignored"
def _fake_bundle(tenant_id, llm_type, chat_id):
llm_calls["bundle"] = (tenant_id, llm_type, chat_id)
def _fake_bundle(tenant_id, model_config, lang="Chinese", **kwargs):
llm_calls["bundle"] = (tenant_id, model_config)
return _FakeChat()
monkeypatch.setattr(module, "LLMBundle", _fake_bundle)
monkeypatch.setattr(module, "load_prompt", lambda name: f"prompt-{name}")
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": model_name}))
_set_request_json(monkeypatch, module, {"question": "solar", "search_id": "search-1"})
res = _run(module.related_questions.__wrapped__())
assert res["code"] == 0
assert res["data"] == ["Alpha", "Beta"]
assert llm_calls["bundle"][0] == "user-1"
assert llm_calls["bundle"][2] == "chat-x"
assert llm_calls["options"] == {"temperature": 0.2}
assert llm_calls["prompt"] == "prompt-related_question"
assert "Keywords: solar" in llm_calls["messages"][0]["content"]

View File

@ -137,11 +137,34 @@ def _load_dialog_module(monkeypatch):
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
class _MockTableObject:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
def to_dict(self):
return {k: v for k, v in self.__dict__.items()}
class _TenantLLMService:
@staticmethod
def split_model_name_and_factory(embd_id):
return embd_id.split("@")
@staticmethod
def get_api_key(tenant_id, model_name):
return _MockTableObject(
id=1,
tenant_id=tenant_id,
llm_factory="",
model_type="chat",
llm_name=model_name,
api_key="fake-api-key",
api_base="https://api.example.com",
max_tokens=8192,
used_tokens=0,
status=1
)
tenant_llm_service_mod.TenantLLMService = _TenantLLMService
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
@ -253,8 +276,8 @@ def test_set_dialog_branch_matrix_unit(monkeypatch):
monkeypatch.setattr(module, "duplicate_name", _dup_name)
monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(name="new dialog")])
monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x")))
monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="embd-a@builtin")])
monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x", tenant_llm_id=1)))
monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="embd-a@builtin", tenant_embd_id=2)])
monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@"))
monkeypatch.setattr(module.DialogService, "save", lambda **kwargs: captured.update(kwargs) or False)
_set_request_json(
@ -301,7 +324,7 @@ def test_set_dialog_branch_matrix_unit(monkeypatch):
res = _run(handler())
assert res["message"] == "Tenant not found!"
monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x")))
monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x", tenant_llm_id=1)))
monkeypatch.setattr(
module,
"get_request_json",
@ -316,7 +339,7 @@ def test_set_dialog_branch_matrix_unit(monkeypatch):
monkeypatch.setattr(
module.KnowledgebaseService,
"get_by_ids",
lambda _ids: [SimpleNamespace(embd_id="embd-a@f1"), SimpleNamespace(embd_id="embd-b@f2")],
lambda _ids: [SimpleNamespace(embd_id="embd-a@f1", tenant_embd_id=2), SimpleNamespace(embd_id="embd-b@f2", tenant_embd_id=2)],
)
monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@"))
res = _run(handler())

View File

@ -51,11 +51,19 @@ class _DummyTenantLLMModel:
llm_factory = _ExprField("llm_factory")
llm_name = _ExprField("llm_name")
def __init__(self, id=None, **kwargs):
self.id = id
self.api_key = None
self.status = None
for key, value in kwargs.items():
setattr(self, key, value)
class _TenantLLMRow:
def __init__(
self,
*,
id,
llm_name,
llm_factory,
model_type,
@ -65,6 +73,7 @@ class _TenantLLMRow:
api_base="",
max_tokens=8192,
):
self.id = id
self.llm_name = llm_name
self.llm_factory = llm_factory
self.model_type = model_type
@ -76,6 +85,7 @@ class _TenantLLMRow:
def to_dict(self):
return {
"id": self.id,
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"model_type": self.model_type,
@ -246,8 +256,8 @@ def test_list_app_grouping_availability_and_merge(monkeypatch):
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id))
tenant_rows = [
_TenantLLMRow(llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
_TenantLLMRow(llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
_TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
_TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
]
monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: tenant_rows)
@ -263,7 +273,7 @@ def test_list_app_grouping_availability_and_merge(monkeypatch):
monkeypatch.setenv("TEI_MODEL", "tei-embed")
res = _run(module.list_app())
assert res["code"] == 0
assert res["code"] == 0, res["message"]
assert ensure_calls == ["tenant-1"]
data = res["data"]
@ -291,8 +301,8 @@ def test_list_app_model_type_filter(monkeypatch):
module.TenantLLMService,
"query",
lambda **_kwargs: [
_TenantLLMRow(llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
_TenantLLMRow(llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
_TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
_TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
],
)
monkeypatch.setattr(
@ -306,7 +316,7 @@ def test_list_app_model_type_filter(monkeypatch):
monkeypatch.setattr(module, "request", SimpleNamespace(args={"model_type": "chat"}))
res = _run(module.list_app())
assert res["code"] == 0
assert res["code"] == 0, res["message"]
assert list(res["data"].keys()) == ["CustomFactory"]
assert res["data"]["CustomFactory"][0]["model_type"] == "chat"
@ -799,7 +809,7 @@ def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch):
monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, _payload: False)
monkeypatch.setattr(module.TenantLLMService, "save", lambda **kwargs: saved.append(kwargs) or True)
res = _call({"llm_factory": "FChatPass", "llm_name": "m", "model_type": module.LLMType.CHAT.value, "api_key": "k"})
assert res["code"] == 0
assert res["code"] == 0, res["message"]
assert res["data"] is True
assert saved
assert saved[0]["llm_factory"] == "FChatPass"
@ -841,6 +851,7 @@ def test_my_llms_include_details_and_exception_unit(monkeypatch):
"query",
lambda **_kwargs: [
_TenantLLMRow(
id=1,
llm_name="chat-model",
llm_factory="FactoryX",
model_type="chat",

View File

@ -32,7 +32,7 @@ def add_memory_func(request, WebApiAuth):
payload = {
"name": f"test_memory_{i}",
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)

View File

@ -45,7 +45,7 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
@ -68,7 +68,7 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
@ -80,7 +80,7 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["something"],
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = create_memory(WebApiAuth, payload)
@ -92,7 +92,7 @@ class TestMemoryCreate:
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res1 = create_memory(WebApiAuth, payload)

View File

@ -216,11 +216,34 @@ def _load_user_app(monkeypatch):
tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
class _MockTableObject:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
def to_dict(self):
return {k: v for k, v in self.__dict__.items()}
class _StubTenantLLMService:
@staticmethod
def insert_many(_payload):
return True
@staticmethod
def get_api_key(tenant_id, model_name):
return _MockTableObject(
id=1,
tenant_id=tenant_id,
llm_factory="",
model_type="chat",
llm_name=model_name,
api_key="fake-api-key",
api_base="https://api.example.com",
max_tokens=8192,
used_tokens=0,
status=1
)
tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)