mirror of
https://github.com/langgenius/dify.git
synced 2026-05-29 05:07:55 +08:00
Merge branch 'main' into feat/evaluation-fe
This commit is contained in:
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -166,6 +166,7 @@
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
/web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/develop/template/*.mdx @JzoNgKVO @iamjoel @RiskeyL
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
|
||||
@ -11,7 +11,7 @@ from controllers.console.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -22,7 +22,7 @@ from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
@ -64,9 +64,9 @@ class RuleGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
|
||||
@ -93,9 +93,9 @@ class RuleCodeGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
@ -125,9 +125,9 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
@ -157,9 +157,9 @@ class InstructionGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
|
||||
@ -11,11 +11,16 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import App, AppMCPServer
|
||||
|
||||
@ -92,8 +97,8 @@ class AppMCPServerController(Resource):
|
||||
@login_required
|
||||
@setup_required
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, app_model: App):
|
||||
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
description = payload.description
|
||||
@ -163,8 +168,8 @@ class AppMCPServerRefreshController(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, server_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, server_id: UUID):
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
|
||||
@ -8,12 +8,17 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -138,9 +143,8 @@ class DefaultModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -156,9 +160,8 @@ class DefaultModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserPostDefault.model_validate(console_ns.payload)
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args.model_settings
|
||||
@ -189,9 +192,8 @@ class ModelProviderModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider):
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
@ -202,9 +204,9 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
# To save the model's load balance configs
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserPostModels.model_validate(console_ns.payload)
|
||||
|
||||
if args.config_from == "custom-model":
|
||||
@ -249,9 +251,8 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, provider: str):
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -268,9 +269,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -323,9 +323,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
args = ParserCreateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -355,8 +354,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, provider: str):
|
||||
args = ParserUpdateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -382,8 +381,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
args = ParserDeleteCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -406,8 +405,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
args = ParserSwitch.model_validate(console_ns.payload)
|
||||
|
||||
service = ModelProviderService()
|
||||
@ -430,9 +429,8 @@ class ModelProviderModelEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, provider: str):
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -452,9 +450,8 @@ class ModelProviderModelDisableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, provider: str):
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -480,8 +477,8 @@ class ModelProviderModelValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
args = ParserValidate.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -515,9 +512,9 @@ class ModelProviderModelParameterRuleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
args = ParserParameter.model_validate(request.args.to_dict(flat=True))
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
@ -532,8 +529,8 @@ class ModelProviderAvailableModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, model_type: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
|
||||
@ -34,7 +34,6 @@ def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
|
||||
|
||||
with app.test_request_context(
|
||||
@ -42,7 +41,7 @@ def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
response = method()
|
||||
response = method("t1")
|
||||
|
||||
assert response == {"rules": []}
|
||||
|
||||
@ -51,8 +50,6 @@ def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatc
|
||||
api = generator_module.RuleCodeGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise ProviderTokenNotInitError("missing token")
|
||||
|
||||
@ -64,15 +61,13 @@ def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatc
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method()
|
||||
method("t1")
|
||||
|
||||
|
||||
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
|
||||
|
||||
with app.test_request_context(
|
||||
@ -85,7 +80,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
response, status = method("t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "app app-1 not found"
|
||||
@ -95,8 +90,6 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
@ -111,7 +104,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
response, status = method("t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "workflow app-1 not found"
|
||||
@ -121,8 +114,6 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
@ -139,7 +130,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
response, status = method("t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "node node-1 not found"
|
||||
@ -149,8 +140,6 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
@ -174,7 +163,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
response = method("t1")
|
||||
|
||||
assert response == {"code": "x"}
|
||||
|
||||
@ -183,7 +172,6 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(
|
||||
generator_module.LLMGenerator,
|
||||
"instruction_modify_legacy",
|
||||
@ -201,7 +189,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
response = method("t1")
|
||||
|
||||
assert response == {"instruction": "ok"}
|
||||
|
||||
@ -210,8 +198,6 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
@ -223,7 +209,7 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
response, status = method("t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "incompatible parameters"
|
||||
|
||||
@ -121,7 +121,6 @@ class TestAppMCPServerController:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")),
|
||||
patch("controllers.console.app.mcp_server.db.session.add"),
|
||||
patch("controllers.console.app.mcp_server.db.session.commit"),
|
||||
patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", return_value="server-code"),
|
||||
@ -131,7 +130,7 @@ class TestAppMCPServerController:
|
||||
),
|
||||
):
|
||||
response, status_code = method(
|
||||
api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description")
|
||||
api, "tenant-1", app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description")
|
||||
)
|
||||
|
||||
assert response == {"id": "server-1"}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -34,15 +34,11 @@ class TestDefaultModelApi:
|
||||
"/",
|
||||
query_string={"model_type": ModelType.LLM},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"}
|
||||
|
||||
result = method(api)
|
||||
result = method(api, "tenant1")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
@ -62,13 +58,9 @@ class TestDefaultModelApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "tenant1")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -78,12 +70,11 @@ class TestDefaultModelApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model_type": ModelType.LLM}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_default_model_of_model_type.return_value = None
|
||||
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
@ -95,15 +86,11 @@ class TestModelProviderModelApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_models_by_provider.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
result = method(api, "tenant1", "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
@ -122,14 +109,10 @@ class TestModelProviderModelApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
result, status = method(api, "tenant1", "openai")
|
||||
|
||||
assert status == 200
|
||||
|
||||
@ -144,13 +127,9 @@ class TestModelProviderModelApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
result, status = method(api, "tenant1", "openai")
|
||||
|
||||
assert status == 204
|
||||
|
||||
@ -160,12 +139,11 @@ class TestModelProviderModelApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_models_by_provider.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
result = method(api, "t1", "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
@ -183,10 +161,6 @@ class TestModelProviderModelCredentialApi:
|
||||
"model_type": ModelType.LLM,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as provider_service,
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service,
|
||||
):
|
||||
@ -198,7 +172,7 @@ class TestModelProviderModelCredentialApi:
|
||||
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "openai")
|
||||
result = method(api, "tenant1", "openai")
|
||||
|
||||
assert "credentials" in result
|
||||
|
||||
@ -214,13 +188,9 @@ class TestModelProviderModelCredentialApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
result, status = method(api, "tenant1", "openai")
|
||||
|
||||
assert status == 201
|
||||
|
||||
@ -230,7 +200,6 @@ class TestModelProviderModelCredentialApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
|
||||
):
|
||||
@ -238,7 +207,7 @@ class TestModelProviderModelCredentialApi:
|
||||
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "openai")
|
||||
result = method(api, "t1", "openai")
|
||||
|
||||
assert result["credentials"] == {}
|
||||
|
||||
@ -254,10 +223,9 @@ class TestModelProviderModelCredentialApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
result, status = method(api, "t1", "openai")
|
||||
|
||||
assert status == 204
|
||||
|
||||
@ -275,13 +243,9 @@ class TestModelProviderModelCredentialSwitchApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
result = method(api, "tenant1", "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -298,13 +262,9 @@ class TestModelEnableDisableApis:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
result = method(api, "tenant1", "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -319,13 +279,9 @@ class TestModelEnableDisableApis:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
result = method(api, "tenant1", "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -343,13 +299,9 @@ class TestModelProviderModelValidateApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
result = method(api, "tenant1", "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -366,15 +318,11 @@ class TestModelProviderModelValidateApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid")
|
||||
|
||||
result = method(api, "openai")
|
||||
result = method(api, "tenant1", "openai")
|
||||
|
||||
assert result["result"] == "error"
|
||||
|
||||
@ -386,15 +334,11 @@ class TestParameterAndAvailableModels:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt-4"}),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_model_parameter_rules.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
result = method(api, "tenant1", "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
@ -404,15 +348,11 @@ class TestParameterAndAvailableModels:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM)
|
||||
result = method(api, "tenant1", ModelType.LLM)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
@ -422,12 +362,11 @@ class TestParameterAndAvailableModels:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt"}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_model_parameter_rules.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
result = method(api, "t1", "openai")
|
||||
|
||||
assert result["data"] == []
|
||||
|
||||
@ -437,11 +376,10 @@ class TestParameterAndAvailableModels:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM)
|
||||
result = method(api, "t1", ModelType.LLM)
|
||||
|
||||
assert result["data"] == []
|
||||
|
||||
Reference in New Issue
Block a user