mirror of
https://github.com/langgenius/dify.git
synced 2026-03-30 10:30:16 +08:00
Merge branch 'feat/model-plugins-implementing' into deploy/dev
# Conflicts: # web/app/components/workflow/nodes/http/components/key-value/key-value-edit/index.tsx # web/app/components/workflow/nodes/human-input/components/delivery-method/recipient/email-item.tsx # web/app/components/workflow/nodes/trigger-webhook/components/generic-table.tsx
This commit is contained in:
@ -138,20 +138,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
query = self.application_generate_entity.query
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
stop, new_inputs, new_query = self.handle_input_moderation(
|
||||
app_record=self._app,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=self.message.id,
|
||||
):
|
||||
)
|
||||
if stop:
|
||||
return
|
||||
|
||||
self.application_generate_entity.inputs = new_inputs
|
||||
self.application_generate_entity.query = new_query
|
||||
system_inputs.query = new_query
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=self._app,
|
||||
message=self.message,
|
||||
query=query,
|
||||
query=new_query,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
):
|
||||
return
|
||||
@ -163,7 +168,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
user_inputs=new_inputs,
|
||||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
@ -240,10 +245,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> bool:
|
||||
) -> tuple[bool, Mapping[str, Any], str]:
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
_, inputs, query = self.moderation_for_inputs(
|
||||
_, new_inputs, new_query = self.moderation_for_inputs(
|
||||
app_id=app_record.id,
|
||||
tenant_id=app_generate_entity.app_config.tenant_id,
|
||||
app_generate_entity=app_generate_entity,
|
||||
@ -253,9 +258,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
except ModerationError as e:
|
||||
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
|
||||
return True
|
||||
return True, inputs, query
|
||||
|
||||
return False
|
||||
return False, new_inputs, new_query
|
||||
|
||||
def handle_annotation_reply(
|
||||
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
||||
|
||||
@ -33,6 +33,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
|
||||
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(
|
||||
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
|
||||
@ -435,6 +437,46 @@ class ToolTransformService:
|
||||
:return: list of ToolParameter instances
|
||||
"""
|
||||
|
||||
def resolve_property_type(prop: dict[str, Any], depth: int = 0) -> str:
|
||||
"""
|
||||
Resolve a JSON schema property type while guarding against cyclic or deeply nested unions.
|
||||
"""
|
||||
if depth >= ToolTransformService._MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH:
|
||||
return "string"
|
||||
prop_type = prop.get("type")
|
||||
if isinstance(prop_type, list):
|
||||
non_null_types = [type_name for type_name in prop_type if type_name != "null"]
|
||||
if non_null_types:
|
||||
return non_null_types[0]
|
||||
if prop_type:
|
||||
return "string"
|
||||
elif isinstance(prop_type, str):
|
||||
if prop_type == "null":
|
||||
return "string"
|
||||
return prop_type
|
||||
|
||||
for union_key in ("anyOf", "oneOf"):
|
||||
union_schemas = prop.get(union_key)
|
||||
if not isinstance(union_schemas, list):
|
||||
continue
|
||||
|
||||
for union_schema in union_schemas:
|
||||
if not isinstance(union_schema, dict):
|
||||
continue
|
||||
union_type = resolve_property_type(union_schema, depth + 1)
|
||||
if union_type != "null":
|
||||
return union_type
|
||||
|
||||
all_of_schemas = prop.get("allOf")
|
||||
if isinstance(all_of_schemas, list):
|
||||
for all_of_schema in all_of_schemas:
|
||||
if not isinstance(all_of_schema, dict):
|
||||
continue
|
||||
all_of_type = resolve_property_type(all_of_schema, depth + 1)
|
||||
if all_of_type != "null":
|
||||
return all_of_type
|
||||
return "string"
|
||||
|
||||
def create_parameter(
|
||||
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
|
||||
) -> ToolParameter:
|
||||
@ -461,10 +503,7 @@ class ToolTransformService:
|
||||
parameters = []
|
||||
for name, prop in props.items():
|
||||
current_description = prop.get("description", "")
|
||||
prop_type = prop.get("type", "string")
|
||||
|
||||
if isinstance(prop_type, list):
|
||||
prop_type = prop_type[0]
|
||||
prop_type = resolve_property_type(prop)
|
||||
if prop_type in TYPE_MAPPING:
|
||||
prop_type = TYPE_MAPPING[prop_type]
|
||||
input_schema = prop if prop_type in COMPLEX_TYPES else None
|
||||
|
||||
@ -125,7 +125,11 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query),
|
||||
),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
|
||||
@ -265,7 +269,11 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query),
|
||||
),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
|
||||
@ -412,7 +420,11 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query),
|
||||
),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
|
||||
|
||||
@ -0,0 +1,170 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueStopEvent
|
||||
from core.moderation.base import ModerationError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def build_runner():
|
||||
"""Construct a minimal AdvancedChatAppRunner with heavy dependencies mocked."""
|
||||
app_id = str(uuid4())
|
||||
workflow_id = str(uuid4())
|
||||
|
||||
# Mocks for constructor args
|
||||
mock_queue_manager = MagicMock()
|
||||
|
||||
mock_conversation = MagicMock()
|
||||
mock_conversation.id = str(uuid4())
|
||||
mock_conversation.app_id = app_id
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = str(uuid4())
|
||||
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.id = workflow_id
|
||||
mock_workflow.tenant_id = str(uuid4())
|
||||
mock_workflow.app_id = app_id
|
||||
mock_workflow.type = "chat"
|
||||
mock_workflow.graph_dict = {}
|
||||
mock_workflow.environment_variables = []
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.app_id = app_id
|
||||
mock_app_config.workflow_id = workflow_id
|
||||
mock_app_config.tenant_id = str(uuid4())
|
||||
|
||||
gen = MagicMock(spec=AdvancedChatAppGenerateEntity)
|
||||
gen.app_config = mock_app_config
|
||||
gen.inputs = {"q": "raw"}
|
||||
gen.query = "raw-query"
|
||||
gen.files = []
|
||||
gen.user_id = str(uuid4())
|
||||
gen.invoke_from = InvokeFrom.SERVICE_API
|
||||
gen.workflow_run_id = str(uuid4())
|
||||
gen.task_id = str(uuid4())
|
||||
gen.call_depth = 0
|
||||
gen.single_iteration_run = None
|
||||
gen.single_loop_run = None
|
||||
gen.trace_manager = None
|
||||
|
||||
runner = AdvancedChatAppRunner(
|
||||
application_generate_entity=gen,
|
||||
queue_manager=mock_queue_manager,
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
dialogue_count=1,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=mock_workflow,
|
||||
system_user_id=str(uuid4()),
|
||||
app=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
def _patch_common_run_deps(runner: AdvancedChatAppRunner):
|
||||
"""Context manager that patches common heavy deps used by run()."""
|
||||
return patch.multiple(
|
||||
"core.app.apps.advanced_chat.app_runner",
|
||||
Session=MagicMock(
|
||||
return_value=MagicMock(
|
||||
__enter__=lambda s: s,
|
||||
__exit__=lambda *a, **k: False,
|
||||
scalar=lambda *a, **k: MagicMock(),
|
||||
),
|
||||
),
|
||||
select=MagicMock(),
|
||||
db=MagicMock(engine=MagicMock()),
|
||||
RedisChannel=MagicMock(),
|
||||
redis_client=MagicMock(),
|
||||
WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}),
|
||||
GraphRuntimeState=MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
def test_handle_input_moderation_stops_on_moderation_error(build_runner):
|
||||
runner = build_runner
|
||||
|
||||
# moderation_for_inputs raises ModerationError -> should stop and emit stop event
|
||||
with (
|
||||
patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
|
||||
patch.object(runner, "_complete_with_stream_output") as mock_complete,
|
||||
):
|
||||
stop, new_inputs, new_query = runner.handle_input_moderation(
|
||||
app_record=MagicMock(),
|
||||
app_generate_entity=runner.application_generate_entity,
|
||||
inputs={"k": "v"},
|
||||
query="hello",
|
||||
message_id="mid",
|
||||
)
|
||||
|
||||
assert stop is True
|
||||
# inputs/query should be unchanged on error path
|
||||
assert new_inputs == {"k": "v"}
|
||||
assert new_query == "hello"
|
||||
# ensure stopped_by reason is INPUT_MODERATION
|
||||
assert mock_complete.called
|
||||
args, kwargs = mock_complete.call_args
|
||||
assert kwargs.get("stopped_by") == QueueStopEvent.StopBy.INPUT_MODERATION
|
||||
|
||||
|
||||
def test_run_applies_overridden_inputs_and_query_from_moderation(build_runner):
|
||||
runner = build_runner
|
||||
|
||||
overridden_inputs = {"q": "sanitized"}
|
||||
overridden_query = "sanitized-query"
|
||||
|
||||
with (
|
||||
_patch_common_run_deps(runner),
|
||||
patch.object(
|
||||
runner,
|
||||
"moderation_for_inputs",
|
||||
return_value=(True, overridden_inputs, overridden_query),
|
||||
) as mock_moderate,
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False) as mock_anno,
|
||||
patch.object(runner, "_init_graph", return_value=MagicMock()) as mock_init_graph,
|
||||
):
|
||||
runner.run()
|
||||
|
||||
# moderation called with original values
|
||||
mock_moderate.assert_called_once()
|
||||
|
||||
# application_generate_entity should be updated to overridden values
|
||||
assert runner.application_generate_entity.inputs == overridden_inputs
|
||||
assert runner.application_generate_entity.query == overridden_query
|
||||
|
||||
# annotation reply should use the new query
|
||||
mock_anno.assert_called()
|
||||
assert mock_anno.call_args.kwargs.get("query") == overridden_query
|
||||
|
||||
# since not stopped, graph initialization should proceed
|
||||
assert mock_init_graph.called
|
||||
|
||||
|
||||
def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_runner):
|
||||
runner = build_runner
|
||||
|
||||
with (
|
||||
_patch_common_run_deps(runner),
|
||||
# Simulate handle_input_moderation signalling to stop
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(True, runner.application_generate_entity.inputs, runner.application_generate_entity.query),
|
||||
) as mock_handle,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(runner, "handle_annotation_reply") as mock_anno,
|
||||
):
|
||||
runner.run()
|
||||
|
||||
mock_handle.assert_called_once()
|
||||
# Ensure no further steps executed
|
||||
mock_anno.assert_not_called()
|
||||
mock_init_graph.assert_not_called()
|
||||
466
api/tests/unit_tests/services/test_api_token_service.py
Normal file
466
api/tests/unit_tests/services/test_api_token_service.py
Normal file
@ -0,0 +1,466 @@
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services.api_token_service as api_token_service_module
|
||||
from services.api_token_service import ApiTokenCache, CachedApiToken
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Fixture providing common DB session mocking for query_token_from_db tests."""
|
||||
fake_engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
|
||||
patch.object(api_token_service_module, "Session", return_value=session_context) as mock_session_class,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set,
|
||||
patch.object(api_token_service_module, "record_token_usage") as mock_record_usage,
|
||||
):
|
||||
yield {
|
||||
"session": session,
|
||||
"mock_session_class": mock_session_class,
|
||||
"mock_cache_set": mock_cache_set,
|
||||
"mock_record_usage": mock_record_usage,
|
||||
"fake_engine": fake_engine,
|
||||
}
|
||||
|
||||
|
||||
class TestQueryTokenFromDb:
|
||||
def test_should_return_api_token_and_cache_when_token_exists(self, mock_db_session):
|
||||
"""Test DB lookup success path caches token and records usage."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
api_token = MagicMock()
|
||||
|
||||
mock_db_session["session"].scalar.return_value = api_token
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.query_token_from_db(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == api_token
|
||||
mock_db_session["mock_session_class"].assert_called_once_with(
|
||||
mock_db_session["fake_engine"], expire_on_commit=False
|
||||
)
|
||||
mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, api_token)
|
||||
mock_db_session["mock_record_usage"].assert_called_once_with(auth_token, scope)
|
||||
|
||||
def test_should_cache_null_and_raise_unauthorized_when_token_not_found(self, mock_db_session):
|
||||
"""Test DB lookup miss path caches null marker and raises Unauthorized."""
|
||||
# Arrange
|
||||
auth_token = "missing-token"
|
||||
scope = "app"
|
||||
|
||||
mock_db_session["session"].scalar.return_value = None
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(Unauthorized, match="Access token is invalid"):
|
||||
api_token_service_module.query_token_from_db(auth_token, scope)
|
||||
|
||||
mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, None)
|
||||
mock_db_session["mock_record_usage"].assert_not_called()
|
||||
|
||||
|
||||
class TestRecordTokenUsage:
|
||||
def test_should_write_active_key_with_iso_timestamp_and_ttl(self):
|
||||
"""Test record_token_usage writes usage timestamp with one-hour TTL."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "dataset"
|
||||
fixed_time = datetime(2026, 2, 24, 12, 0, 0)
|
||||
expected_key = ApiTokenCache.make_active_key(auth_token, scope)
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "naive_utc_now", return_value=fixed_time),
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
):
|
||||
# Act
|
||||
api_token_service_module.record_token_usage(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(expected_key, fixed_time.isoformat(), ex=3600)
|
||||
|
||||
def test_should_not_raise_when_redis_write_fails(self):
|
||||
"""Test record_token_usage swallows Redis errors."""
|
||||
# Arrange
|
||||
with patch.object(api_token_service_module, "redis_client") as mock_redis:
|
||||
mock_redis.set.side_effect = Exception("redis unavailable")
|
||||
|
||||
# Act / Assert
|
||||
api_token_service_module.record_token_usage("token-123", "app")
|
||||
|
||||
|
||||
class TestFetchTokenWithSingleFlight:
|
||||
def test_should_return_cached_token_when_lock_acquired_and_cache_filled(self):
|
||||
"""Test single-flight returns cache when another request already populated it."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
cached_token = CachedApiToken(
|
||||
id="id-1",
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
type="app",
|
||||
token=auth_token,
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token) as mock_cache_get,
|
||||
patch.object(api_token_service_module, "query_token_from_db") as mock_query_db,
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == cached_token
|
||||
mock_redis.lock.assert_called_once_with(
|
||||
f"api_token_query_lock:{scope}:{auth_token}",
|
||||
timeout=10,
|
||||
blocking_timeout=5,
|
||||
)
|
||||
lock.acquire.assert_called_once_with(blocking=True)
|
||||
lock.release.assert_called_once()
|
||||
mock_cache_get.assert_called_once_with(auth_token, scope)
|
||||
mock_query_db.assert_not_called()
|
||||
|
||||
def test_should_query_db_when_lock_acquired_and_cache_missed(self):
|
||||
"""Test single-flight queries DB when cache remains empty after lock acquisition."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
db_token = MagicMock()
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None),
|
||||
patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db,
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == db_token
|
||||
mock_query_db.assert_called_once_with(auth_token, scope)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
def test_should_query_db_directly_when_lock_not_acquired(self):
|
||||
"""Test lock timeout branch falls back to direct DB query."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
db_token = MagicMock()
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "get") as mock_cache_get,
|
||||
patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db,
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == db_token
|
||||
mock_cache_get.assert_not_called()
|
||||
mock_query_db.assert_called_once_with(auth_token, scope)
|
||||
lock.release.assert_not_called()
|
||||
|
||||
def test_should_reraise_unauthorized_from_db_query(self):
|
||||
"""Test Unauthorized from DB query is propagated unchanged."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None),
|
||||
patch.object(
|
||||
api_token_service_module,
|
||||
"query_token_from_db",
|
||||
side_effect=Unauthorized("Access token is invalid"),
|
||||
),
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(Unauthorized, match="Access token is invalid"):
|
||||
api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
lock.release.assert_called_once()
|
||||
|
||||
def test_should_fallback_to_db_query_when_lock_raises_exception(self):
|
||||
"""Test Redis lock errors fall back to direct DB query."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
db_token = MagicMock()
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.side_effect = RuntimeError("redis lock error")
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db,
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == db_token
|
||||
mock_query_db.assert_called_once_with(auth_token, scope)
|
||||
|
||||
|
||||
class TestApiTokenCacheTenantBranches:
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_delete_with_scope_should_remove_from_tenant_index_when_tenant_found(self, mock_redis):
|
||||
"""Test scoped delete removes cache key and tenant index membership."""
|
||||
# Arrange
|
||||
token = "token-123"
|
||||
scope = "app"
|
||||
cache_key = ApiTokenCache._make_cache_key(token, scope)
|
||||
cached_token = CachedApiToken(
|
||||
id="id-1",
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
type="app",
|
||||
token=token,
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
mock_redis.get.return_value = cached_token.model_dump_json().encode("utf-8")
|
||||
|
||||
with patch.object(ApiTokenCache, "_remove_from_tenant_index") as mock_remove_index:
|
||||
# Act
|
||||
result = ApiTokenCache.delete(token, scope)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once_with(cache_key)
|
||||
mock_remove_index.assert_called_once_with("tenant-1", cache_key)
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_invalidate_by_tenant_should_delete_all_indexed_cache_keys(self, mock_redis):
|
||||
"""Test tenant invalidation deletes indexed cache entries and index key."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
|
||||
mock_redis.smembers.return_value = {
|
||||
b"api_token:app:token-1",
|
||||
b"api_token:any:token-2",
|
||||
}
|
||||
|
||||
# Act
|
||||
result = ApiTokenCache.invalidate_by_tenant(tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_redis.smembers.assert_called_once_with(index_key)
|
||||
mock_redis.delete.assert_any_call("api_token:app:token-1")
|
||||
mock_redis.delete.assert_any_call("api_token:any:token-2")
|
||||
mock_redis.delete.assert_any_call(index_key)
|
||||
|
||||
|
||||
class TestApiTokenCacheCoreBranches:
|
||||
def test_cached_api_token_repr_should_include_id_and_type(self):
|
||||
"""Test CachedApiToken __repr__ includes key identity fields."""
|
||||
token = CachedApiToken(
|
||||
id="id-123",
|
||||
app_id="app-123",
|
||||
tenant_id="tenant-123",
|
||||
type="app",
|
||||
token="token-123",
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
assert repr(token) == "<CachedApiToken id=id-123 type=app>"
|
||||
|
||||
def test_serialize_token_should_handle_cached_api_token_instances(self):
|
||||
"""Test serialization path when input is already a CachedApiToken."""
|
||||
token = CachedApiToken(
|
||||
id="id-123",
|
||||
app_id="app-123",
|
||||
tenant_id="tenant-123",
|
||||
type="app",
|
||||
token="token-123",
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
serialized = ApiTokenCache._serialize_token(token)
|
||||
|
||||
assert isinstance(serialized, bytes)
|
||||
assert b'"id":"id-123"' in serialized
|
||||
assert b'"token":"token-123"' in serialized
|
||||
|
||||
def test_deserialize_token_should_return_none_for_null_markers(self):
|
||||
"""Test null cache marker deserializes to None."""
|
||||
assert ApiTokenCache._deserialize_token("null") is None
|
||||
assert ApiTokenCache._deserialize_token(b"null") is None
|
||||
|
||||
def test_deserialize_token_should_return_none_for_invalid_payload(self):
|
||||
"""Test invalid serialized payload returns None."""
|
||||
assert ApiTokenCache._deserialize_token("not-json") is None
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_get_should_return_none_on_cache_miss(self, mock_redis):
|
||||
"""Test cache miss branch in ApiTokenCache.get."""
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = ApiTokenCache.get("token-123", "app")
|
||||
|
||||
assert result is None
|
||||
mock_redis.get.assert_called_once_with("api_token:app:token-123")
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_get_should_deserialize_cached_payload_on_cache_hit(self, mock_redis):
|
||||
"""Test cache hit branch in ApiTokenCache.get."""
|
||||
token = CachedApiToken(
|
||||
id="id-123",
|
||||
app_id="app-123",
|
||||
tenant_id="tenant-123",
|
||||
type="app",
|
||||
token="token-123",
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
mock_redis.get.return_value = token.model_dump_json().encode("utf-8")
|
||||
|
||||
result = ApiTokenCache.get("token-123", "app")
|
||||
|
||||
assert isinstance(result, CachedApiToken)
|
||||
assert result.id == "id-123"
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_add_to_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis):
|
||||
"""Test tenant index update exits early for missing tenant id."""
|
||||
ApiTokenCache._add_to_tenant_index(None, "api_token:app:token-123")
|
||||
|
||||
mock_redis.sadd.assert_not_called()
|
||||
mock_redis.expire.assert_not_called()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_add_to_tenant_index_should_swallow_index_update_errors(self, mock_redis):
|
||||
"""Test tenant index update handles Redis write errors gracefully."""
|
||||
mock_redis.sadd.side_effect = Exception("redis down")
|
||||
|
||||
ApiTokenCache._add_to_tenant_index("tenant-123", "api_token:app:token-123")
|
||||
|
||||
mock_redis.sadd.assert_called_once()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_remove_from_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis):
|
||||
"""Test tenant index removal exits early for missing tenant id."""
|
||||
ApiTokenCache._remove_from_tenant_index(None, "api_token:app:token-123")
|
||||
|
||||
mock_redis.srem.assert_not_called()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_remove_from_tenant_index_should_swallow_redis_errors(self, mock_redis):
|
||||
"""Test tenant index removal handles Redis errors gracefully."""
|
||||
mock_redis.srem.side_effect = Exception("redis down")
|
||||
|
||||
ApiTokenCache._remove_from_tenant_index("tenant-123", "api_token:app:token-123")
|
||||
|
||||
mock_redis.srem.assert_called_once()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_set_should_return_false_when_cache_write_raises_exception(self, mock_redis):
|
||||
"""Test set returns False when Redis setex fails."""
|
||||
mock_redis.setex.side_effect = Exception("redis write failed")
|
||||
api_token = MagicMock()
|
||||
api_token.id = "id-123"
|
||||
api_token.app_id = "app-123"
|
||||
api_token.tenant_id = "tenant-123"
|
||||
api_token.type = "app"
|
||||
api_token.token = "token-123"
|
||||
api_token.last_used_at = None
|
||||
api_token.created_at = None
|
||||
|
||||
result = ApiTokenCache.set("token-123", "app", api_token)
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_delete_without_scope_should_return_false_when_scan_fails(self, mock_redis):
|
||||
"""Test delete(scope=None) returns False when scan_iter raises."""
|
||||
mock_redis.scan_iter.side_effect = Exception("scan failed")
|
||||
|
||||
result = ApiTokenCache.delete("token-123", None)
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_delete_with_scope_should_continue_when_tenant_lookup_raises(self, mock_redis):
|
||||
"""Test scoped delete still succeeds when tenant lookup from cache fails."""
|
||||
token = "token-123"
|
||||
scope = "app"
|
||||
cache_key = ApiTokenCache._make_cache_key(token, scope)
|
||||
mock_redis.get.side_effect = Exception("get failed")
|
||||
|
||||
result = ApiTokenCache.delete(token, scope)
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once_with(cache_key)
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_delete_with_scope_should_return_false_when_delete_raises(self, mock_redis):
|
||||
"""Test scoped delete returns False when delete operation fails."""
|
||||
token = "token-123"
|
||||
scope = "app"
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.delete.side_effect = Exception("delete failed")
|
||||
|
||||
result = ApiTokenCache.delete(token, scope)
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_invalidate_by_tenant_should_return_true_when_index_not_found(self, mock_redis):
|
||||
"""Test tenant invalidation returns True when tenant index is empty."""
|
||||
mock_redis.smembers.return_value = set()
|
||||
|
||||
result = ApiTokenCache.invalidate_by_tenant("tenant-123")
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_invalidate_by_tenant_should_return_false_when_redis_raises(self, mock_redis):
|
||||
"""Test tenant invalidation returns False when Redis operation fails."""
|
||||
mock_redis.smembers.side_effect = Exception("redis failed")
|
||||
|
||||
result = ApiTokenCache.invalidate_by_tenant("tenant-123")
|
||||
|
||||
assert result is False
|
||||
@ -0,0 +1,88 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import AppMode
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_managers():
|
||||
"""Fixture that patches all app config manager validate methods.
|
||||
|
||||
Returns a dictionary containing the mocked config_validate methods for each manager.
|
||||
"""
|
||||
with (
|
||||
patch("services.app_model_config_service.ChatAppConfigManager.config_validate") as mock_chat_validate,
|
||||
patch("services.app_model_config_service.AgentChatAppConfigManager.config_validate") as mock_agent_validate,
|
||||
patch(
|
||||
"services.app_model_config_service.CompletionAppConfigManager.config_validate"
|
||||
) as mock_completion_validate,
|
||||
):
|
||||
mock_chat_validate.return_value = {"manager": "chat"}
|
||||
mock_agent_validate.return_value = {"manager": "agent"}
|
||||
mock_completion_validate.return_value = {"manager": "completion"}
|
||||
|
||||
yield {
|
||||
"chat": mock_chat_validate,
|
||||
"agent": mock_agent_validate,
|
||||
"completion": mock_completion_validate,
|
||||
}
|
||||
|
||||
|
||||
class TestAppModelConfigService:
|
||||
@pytest.mark.parametrize(
|
||||
("app_mode", "selected_manager"),
|
||||
[
|
||||
(AppMode.CHAT, "chat"),
|
||||
(AppMode.AGENT_CHAT, "agent"),
|
||||
(AppMode.COMPLETION, "completion"),
|
||||
],
|
||||
)
|
||||
def test_should_route_validation_to_correct_manager_based_on_app_mode(
|
||||
self, app_mode, selected_manager, mock_config_managers
|
||||
):
|
||||
"""Test configuration validation is delegated to the expected manager for each supported app mode."""
|
||||
tenant_id = "tenant-123"
|
||||
config = {"temperature": 0.5}
|
||||
|
||||
mock_chat_validate = mock_config_managers["chat"]
|
||||
mock_agent_validate = mock_config_managers["agent"]
|
||||
mock_completion_validate = mock_config_managers["completion"]
|
||||
|
||||
result = AppModelConfigService.validate_configuration(tenant_id=tenant_id, config=config, app_mode=app_mode)
|
||||
|
||||
assert result == {"manager": selected_manager}
|
||||
|
||||
if selected_manager == "chat":
|
||||
mock_chat_validate.assert_called_once_with(tenant_id, config)
|
||||
mock_agent_validate.assert_not_called()
|
||||
mock_completion_validate.assert_not_called()
|
||||
elif selected_manager == "agent":
|
||||
mock_agent_validate.assert_called_once_with(tenant_id, config)
|
||||
mock_chat_validate.assert_not_called()
|
||||
mock_completion_validate.assert_not_called()
|
||||
else:
|
||||
mock_completion_validate.assert_called_once_with(tenant_id, config)
|
||||
mock_chat_validate.assert_not_called()
|
||||
mock_agent_validate.assert_not_called()
|
||||
|
||||
def test_should_raise_value_error_when_app_mode_is_not_supported(self, mock_config_managers):
|
||||
"""Test unsupported app modes raise ValueError with the invalid mode in the message."""
|
||||
tenant_id = "tenant-123"
|
||||
config = {"temperature": 0.5}
|
||||
|
||||
mock_chat_validate = mock_config_managers["chat"]
|
||||
mock_agent_validate = mock_config_managers["agent"]
|
||||
mock_completion_validate = mock_config_managers["completion"]
|
||||
|
||||
with pytest.raises(ValueError, match=f"Invalid app mode: {AppMode.WORKFLOW}"):
|
||||
AppModelConfigService.validate_configuration(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
)
|
||||
|
||||
mock_chat_validate.assert_not_called()
|
||||
mock_agent_validate.assert_not_called()
|
||||
mock_completion_validate.assert_not_called()
|
||||
507
api/tests/unit_tests/services/test_async_workflow_service.py
Normal file
507
api/tests/unit_tests/services/test_async_workflow_service.py
Normal file
@ -0,0 +1,507 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import services.async_workflow_service as async_workflow_service_module
|
||||
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
|
||||
from services.workflow.entities import AsyncTriggerResponse, TriggerData
|
||||
from services.workflow.queue_dispatcher import QueuePriority
|
||||
|
||||
|
||||
class AsyncWorkflowServiceTestDataFactory:
|
||||
"""Factory helpers for async workflow service unit tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_trigger_data(
|
||||
app_id: str = "app-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
workflow_id: str | None = "workflow-123",
|
||||
root_node_id: str = "root-node-123",
|
||||
) -> TriggerData:
|
||||
"""Create valid trigger data for async workflow execution tests."""
|
||||
return TriggerData(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_id=workflow_id,
|
||||
root_node_id=root_node_id,
|
||||
inputs={"name": "dify"},
|
||||
files=[],
|
||||
trigger_type=AppTriggerType.UNKNOWN,
|
||||
trigger_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
trigger_metadata=None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_trigger_log_with_data(trigger_data: TriggerData, retry_count: int = 0) -> MagicMock:
|
||||
"""Create a mock trigger log with serialized trigger data."""
|
||||
trigger_log = MagicMock()
|
||||
trigger_log.id = "trigger-log-123"
|
||||
trigger_log.trigger_data = trigger_data.model_dump_json()
|
||||
trigger_log.retry_count = retry_count
|
||||
trigger_log.error = "previous-error"
|
||||
trigger_log.status = WorkflowTriggerStatus.FAILED
|
||||
trigger_log.to_dict.return_value = {"id": trigger_log.id}
|
||||
return trigger_log
|
||||
|
||||
|
||||
class TestAsyncWorkflowService:
|
||||
@pytest.fixture
|
||||
def async_workflow_trigger_mocks(self):
|
||||
"""Shared fixture for async workflow trigger tests.
|
||||
|
||||
Yields mocks for:
|
||||
- repo: SQLAlchemyWorkflowTriggerLogRepository
|
||||
- dispatcher_manager_class: QueueDispatcherManager class
|
||||
- dispatcher: dispatcher instance
|
||||
- quota_workflow: QuotaType.WORKFLOW
|
||||
- get_workflow: AsyncWorkflowService._get_workflow method
|
||||
- professional_task: execute_workflow_professional
|
||||
- team_task: execute_workflow_team
|
||||
- sandbox_task: execute_workflow_sandbox
|
||||
"""
|
||||
mock_repo = MagicMock()
|
||||
|
||||
def _create_side_effect(new_log):
|
||||
new_log.id = "trigger-log-123"
|
||||
return new_log
|
||||
|
||||
mock_repo.create.side_effect = _create_side_effect
|
||||
|
||||
mock_dispatcher = MagicMock()
|
||||
quota_workflow = MagicMock()
|
||||
mock_get_workflow = MagicMock()
|
||||
|
||||
mock_professional_task = MagicMock()
|
||||
mock_team_task = MagicMock()
|
||||
mock_sandbox_task = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
return_value=mock_repo,
|
||||
),
|
||||
patch.object(async_workflow_service_module, "QueueDispatcherManager") as mock_dispatcher_manager_class,
|
||||
patch.object(async_workflow_service_module, "WorkflowService"),
|
||||
patch.object(
|
||||
async_workflow_service_module.AsyncWorkflowService,
|
||||
"_get_workflow",
|
||||
) as mock_get_workflow,
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"QuotaType",
|
||||
new=SimpleNamespace(WORKFLOW=quota_workflow),
|
||||
),
|
||||
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
|
||||
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
|
||||
patch.object(async_workflow_service_module, "execute_workflow_sandbox") as mock_sandbox_task,
|
||||
):
|
||||
# Configure dispatcher_manager to return our mock_dispatcher
|
||||
mock_dispatcher_manager_class.return_value.get_dispatcher.return_value = mock_dispatcher
|
||||
|
||||
yield {
|
||||
"repo": mock_repo,
|
||||
"dispatcher_manager_class": mock_dispatcher_manager_class,
|
||||
"dispatcher": mock_dispatcher,
|
||||
"quota_workflow": quota_workflow,
|
||||
"get_workflow": mock_get_workflow,
|
||||
"professional_task": mock_professional_task,
|
||||
"team_task": mock_team_task,
|
||||
"sandbox_task": mock_sandbox_task,
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("queue_name", "selected_task_attr"),
|
||||
[
|
||||
(QueuePriority.PROFESSIONAL, "execute_workflow_professional"),
|
||||
(QueuePriority.TEAM, "execute_workflow_team"),
|
||||
(QueuePriority.SANDBOX, "execute_workflow_sandbox"),
|
||||
],
|
||||
)
|
||||
def test_should_dispatch_to_matching_celery_task_when_triggering_workflow(
|
||||
self, queue_name, selected_task_attr, async_workflow_trigger_mocks
|
||||
):
|
||||
"""Test queue-based task routing and successful async trigger response."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
session.scalar.return_value = app_model
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data()
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-123"
|
||||
|
||||
mocks = async_workflow_trigger_mocks
|
||||
mocks["dispatcher"].get_queue_name.return_value = queue_name
|
||||
mocks["get_workflow"].return_value = workflow
|
||||
|
||||
task_result = MagicMock()
|
||||
task_result.id = "task-123"
|
||||
mocks["professional_task"].delay.return_value = task_result
|
||||
mocks["team_task"].delay.return_value = task_result
|
||||
mocks["sandbox_task"].delay.return_value = task_result
|
||||
|
||||
class DummyAccount:
|
||||
def __init__(self, user_id: str):
|
||||
self.id = user_id
|
||||
|
||||
with patch.object(async_workflow_service_module, "Account", DummyAccount):
|
||||
user = DummyAccount("account-123")
|
||||
|
||||
# Act
|
||||
result = AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, AsyncTriggerResponse)
|
||||
assert result.workflow_trigger_log_id == "trigger-log-123"
|
||||
assert result.task_id == "task-123"
|
||||
assert result.status == "queued"
|
||||
assert result.queue == queue_name
|
||||
|
||||
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
|
||||
assert session.commit.call_count == 2
|
||||
|
||||
created_log = mocks["repo"].create.call_args[0][0]
|
||||
assert created_log.status == WorkflowTriggerStatus.QUEUED
|
||||
assert created_log.queue_name == queue_name
|
||||
assert created_log.created_by_role == CreatorUserRole.ACCOUNT
|
||||
assert created_log.created_by == "account-123"
|
||||
assert created_log.trigger_data == trigger_data.model_dump_json()
|
||||
assert created_log.inputs == json.dumps(dict(trigger_data.inputs))
|
||||
assert created_log.celery_task_id == "task-123"
|
||||
|
||||
task_mocks = {
|
||||
"execute_workflow_professional": mocks["professional_task"],
|
||||
"execute_workflow_team": mocks["team_task"],
|
||||
"execute_workflow_sandbox": mocks["sandbox_task"],
|
||||
}
|
||||
for task_attr, task_mock in task_mocks.items():
|
||||
if task_attr == selected_task_attr:
|
||||
task_mock.delay.assert_called_once_with({"workflow_trigger_log_id": "trigger-log-123"})
|
||||
else:
|
||||
task_mock.delay.assert_not_called()
|
||||
|
||||
def test_should_set_end_user_role_when_triggered_by_end_user(self, async_workflow_trigger_mocks):
|
||||
"""Test that non-account users are tracked as END_USER in trigger logs."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
session.scalar.return_value = app_model
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data()
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-123"
|
||||
|
||||
mocks = async_workflow_trigger_mocks
|
||||
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.SANDBOX
|
||||
mocks["get_workflow"].return_value = workflow
|
||||
|
||||
task_result = MagicMock(id="task-123")
|
||||
mocks["sandbox_task"].delay.return_value = task_result
|
||||
|
||||
user = SimpleNamespace(id="end-user-123")
|
||||
|
||||
# Act
|
||||
AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data)
|
||||
|
||||
# Assert
|
||||
created_log = mocks["repo"].create.call_args[0][0]
|
||||
assert created_log.created_by_role == CreatorUserRole.END_USER
|
||||
assert created_log.created_by == "end-user-123"
|
||||
|
||||
def test_should_raise_workflow_not_found_when_app_does_not_exist(self):
|
||||
"""Test trigger failure when app lookup returns no result."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data(app_id="missing-app")
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository"),
|
||||
patch.object(async_workflow_service_module, "QueueDispatcherManager"),
|
||||
patch.object(async_workflow_service_module, "WorkflowService"),
|
||||
):
|
||||
# Act / Assert
|
||||
with pytest.raises(WorkflowNotFoundError, match="App not found: missing-app"):
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session=session,
|
||||
user=SimpleNamespace(id="user-123"),
|
||||
trigger_data=trigger_data,
|
||||
)
|
||||
|
||||
def test_should_mark_log_rate_limited_and_raise_when_quota_exceeded(self, async_workflow_trigger_mocks):
|
||||
"""Test quota-exceeded path updates trigger log and raises WorkflowQuotaLimitError."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
session.scalar.return_value = app_model
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data()
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-123"
|
||||
|
||||
mocks = async_workflow_trigger_mocks
|
||||
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
|
||||
mocks["get_workflow"].return_value = workflow
|
||||
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
|
||||
feature="workflow",
|
||||
tenant_id="tenant-123",
|
||||
required=1,
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(
|
||||
WorkflowQuotaLimitError,
|
||||
match="Workflow execution quota limit reached for tenant tenant-123",
|
||||
):
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session=session,
|
||||
user=SimpleNamespace(id="user-123"),
|
||||
trigger_data=trigger_data,
|
||||
)
|
||||
|
||||
assert session.commit.call_count == 2
|
||||
updated_log = mocks["repo"].update.call_args[0][0]
|
||||
assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED
|
||||
assert "Quota limit reached" in updated_log.error
|
||||
mocks["professional_task"].delay.assert_not_called()
|
||||
mocks["team_task"].delay.assert_not_called()
|
||||
mocks["sandbox_task"].delay.assert_not_called()
|
||||
|
||||
def test_should_raise_when_reinvoke_target_log_does_not_exist(self):
|
||||
"""Test reinvoke_trigger error path when original trigger log is missing."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
repo = MagicMock()
|
||||
repo.get_by_id.return_value = None
|
||||
|
||||
with patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo):
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Trigger log not found: missing-log"):
|
||||
AsyncWorkflowService.reinvoke_trigger(
|
||||
session=session,
|
||||
user=SimpleNamespace(id="user-123"),
|
||||
workflow_trigger_log_id="missing-log",
|
||||
)
|
||||
|
||||
def test_should_update_original_log_and_requeue_when_reinvoking(self):
|
||||
"""Test reinvoke flow updates original log state and triggers a new async run."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data()
|
||||
trigger_log = AsyncWorkflowServiceTestDataFactory.create_trigger_log_with_data(trigger_data, retry_count=1)
|
||||
repo = MagicMock()
|
||||
repo.get_by_id.return_value = trigger_log
|
||||
|
||||
expected_response = AsyncTriggerResponse(
|
||||
workflow_trigger_log_id="new-trigger-log-456",
|
||||
task_id="task-456",
|
||||
status="queued",
|
||||
queue=QueuePriority.TEAM,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo),
|
||||
patch.object(
|
||||
async_workflow_service_module.AsyncWorkflowService,
|
||||
"trigger_workflow_async",
|
||||
return_value=expected_response,
|
||||
) as mock_trigger_workflow_async,
|
||||
):
|
||||
user = SimpleNamespace(id="user-123")
|
||||
|
||||
# Act
|
||||
response = AsyncWorkflowService.reinvoke_trigger(
|
||||
session=session,
|
||||
user=user,
|
||||
workflow_trigger_log_id="trigger-log-123",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response == expected_response
|
||||
assert trigger_log.status == WorkflowTriggerStatus.RETRYING
|
||||
assert trigger_log.retry_count == 2
|
||||
assert trigger_log.error is None
|
||||
assert trigger_log.triggered_at is not None
|
||||
repo.update.assert_called_once_with(trigger_log)
|
||||
session.commit.assert_called_once()
|
||||
called_trigger_data = mock_trigger_workflow_async.call_args[0][2]
|
||||
assert isinstance(called_trigger_data, TriggerData)
|
||||
assert called_trigger_data.app_id == "app-123"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("repo_result", "expected"),
|
||||
[
|
||||
(None, None),
|
||||
(MagicMock(), {"id": "trigger-log-123"}),
|
||||
],
|
||||
)
|
||||
def test_should_return_trigger_log_dict_or_none(self, repo_result, expected):
|
||||
"""Test get_trigger_log returns serialized log data or None."""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_repo = MagicMock()
|
||||
fake_engine = MagicMock()
|
||||
mock_repo.get_by_id.return_value = repo_result
|
||||
if repo_result:
|
||||
repo_result.to_dict.return_value = expected
|
||||
|
||||
mock_session_context = MagicMock()
|
||||
mock_session_context.__enter__.return_value = mock_session
|
||||
mock_session_context.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
|
||||
patch.object(
|
||||
async_workflow_service_module, "Session", return_value=mock_session_context
|
||||
) as mock_session_class,
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
return_value=mock_repo,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = AsyncWorkflowService.get_trigger_log("trigger-log-123", tenant_id="tenant-123")
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_session_class.assert_called_once_with(fake_engine)
|
||||
mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123")
|
||||
|
||||
def test_should_return_recent_logs_as_dict_list(self):
|
||||
"""Test get_recent_logs converts repository models into dictionaries."""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_repo = MagicMock()
|
||||
log1 = MagicMock()
|
||||
log1.to_dict.return_value = {"id": "log-1"}
|
||||
log2 = MagicMock()
|
||||
log2.to_dict.return_value = {"id": "log-2"}
|
||||
mock_repo.get_recent_logs.return_value = [log1, log2]
|
||||
|
||||
mock_session_context = MagicMock()
|
||||
mock_session_context.__enter__.return_value = mock_session
|
||||
mock_session_context.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
||||
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
return_value=mock_repo,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = AsyncWorkflowService.get_recent_logs(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-123",
|
||||
hours=12,
|
||||
limit=50,
|
||||
offset=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == [{"id": "log-1"}, {"id": "log-2"}]
|
||||
mock_repo.get_recent_logs.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-123",
|
||||
hours=12,
|
||||
limit=50,
|
||||
offset=10,
|
||||
)
|
||||
|
||||
def test_should_return_failed_logs_for_retry_as_dict_list(self):
|
||||
"""Test get_failed_logs_for_retry serializes repository logs into dicts."""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_repo = MagicMock()
|
||||
log = MagicMock()
|
||||
log.to_dict.return_value = {"id": "failed-log-1"}
|
||||
mock_repo.get_failed_for_retry.return_value = [log]
|
||||
|
||||
mock_session_context = MagicMock()
|
||||
mock_session_context.__enter__.return_value = mock_session
|
||||
mock_session_context.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
||||
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
return_value=mock_repo,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = AsyncWorkflowService.get_failed_logs_for_retry(tenant_id="tenant-123", max_retry_count=4, limit=20)
|
||||
|
||||
# Assert
|
||||
assert result == [{"id": "failed-log-1"}]
|
||||
mock_repo.get_failed_for_retry.assert_called_once_with(tenant_id="tenant-123", max_retry_count=4, limit=20)
|
||||
|
||||
|
||||
class TestAsyncWorkflowServiceGetWorkflow:
|
||||
def test_should_return_specific_workflow_when_workflow_id_exists(self):
|
||||
"""Test _get_workflow returns published workflow by id when provided."""
|
||||
# Arrange
|
||||
workflow_service = MagicMock()
|
||||
app_model = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow_service.get_published_workflow_by_id.return_value = workflow
|
||||
|
||||
# Act
|
||||
result = AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-123")
|
||||
|
||||
# Assert
|
||||
assert result == workflow
|
||||
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123")
|
||||
workflow_service.get_published_workflow.assert_not_called()
|
||||
|
||||
def test_should_raise_when_specific_workflow_id_not_found(self):
|
||||
"""Test _get_workflow raises WorkflowNotFoundError for unknown workflow id."""
|
||||
# Arrange
|
||||
workflow_service = MagicMock()
|
||||
app_model = MagicMock()
|
||||
workflow_service.get_published_workflow_by_id.return_value = None
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(WorkflowNotFoundError, match="Published workflow not found: workflow-404"):
|
||||
AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-404")
|
||||
|
||||
def test_should_return_default_published_workflow_when_workflow_id_not_provided(self):
|
||||
"""Test _get_workflow returns default published workflow when no id is provided."""
|
||||
# Arrange
|
||||
workflow_service = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
workflow = MagicMock()
|
||||
workflow_service.get_published_workflow.return_value = workflow
|
||||
|
||||
# Act
|
||||
result = AsyncWorkflowService._get_workflow(workflow_service, app_model)
|
||||
|
||||
# Assert
|
||||
assert result == workflow
|
||||
workflow_service.get_published_workflow.assert_called_once_with(app_model)
|
||||
workflow_service.get_published_workflow_by_id.assert_not_called()
|
||||
|
||||
def test_should_raise_when_default_published_workflow_not_found(self):
|
||||
"""Test _get_workflow raises WorkflowNotFoundError when app has no published workflow."""
|
||||
# Arrange
|
||||
workflow_service = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
workflow_service.get_published_workflow.return_value = None
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(WorkflowNotFoundError, match="No published workflow found for app: app-123"):
|
||||
AsyncWorkflowService._get_workflow(workflow_service, app_model)
|
||||
73
api/tests/unit_tests/services/test_attachment_service.py
Normal file
73
api/tests/unit_tests/services/test_attachment_service.py
Normal file
@ -0,0 +1,73 @@
|
||||
import base64
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.attachment_service as attachment_service_module
|
||||
from models.model import UploadFile
|
||||
from services.attachment_service import AttachmentService
|
||||
|
||||
|
||||
class TestAttachmentService:
|
||||
def test_should_initialize_with_sessionmaker_when_sessionmaker_is_provided(self):
|
||||
"""Test that AttachmentService keeps the provided sessionmaker instance."""
|
||||
session_factory = sessionmaker()
|
||||
|
||||
service = AttachmentService(session_factory=session_factory)
|
||||
|
||||
assert service._session_maker is session_factory
|
||||
|
||||
def test_should_initialize_with_bound_sessionmaker_when_engine_is_provided(self):
|
||||
"""Test that AttachmentService builds a sessionmaker bound to the provided engine."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
|
||||
service = AttachmentService(session_factory=engine)
|
||||
session = service._session_maker()
|
||||
try:
|
||||
assert session.bind == engine
|
||||
finally:
|
||||
session.close()
|
||||
engine.dispose()
|
||||
|
||||
@pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
|
||||
def test_should_raise_assertion_error_when_session_factory_type_is_invalid(self, invalid_session_factory):
|
||||
"""Test that invalid session_factory types are rejected."""
|
||||
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
|
||||
AttachmentService(session_factory=invalid_session_factory)
|
||||
|
||||
def test_should_return_base64_encoded_blob_when_file_exists(self):
|
||||
"""Test that existing files are loaded from storage and returned as base64."""
|
||||
service = AttachmentService(session_factory=sessionmaker())
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.key = "upload-file-key"
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = upload_file
|
||||
service._session_maker = MagicMock(return_value=session)
|
||||
|
||||
with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load:
|
||||
result = service.get_file_base64("file-123")
|
||||
|
||||
assert result == base64.b64encode(b"binary-content").decode()
|
||||
service._session_maker.assert_called_once_with(expire_on_commit=False)
|
||||
session.query.assert_called_once_with(UploadFile)
|
||||
mock_load.assert_called_once_with("upload-file-key")
|
||||
|
||||
def test_should_raise_not_found_when_file_does_not_exist(self):
|
||||
"""Test that missing files raise NotFound and never call storage."""
|
||||
service = AttachmentService(session_factory=sessionmaker())
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
service._session_maker = MagicMock(return_value=session)
|
||||
|
||||
with patch.object(attachment_service_module.storage, "load_once") as mock_load:
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
service.get_file_base64("missing-file")
|
||||
|
||||
service._session_maker.assert_called_once_with(expire_on_commit=False)
|
||||
session.query.assert_called_once_with(UploadFile)
|
||||
mock_load.assert_not_called()
|
||||
@ -0,0 +1,89 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
|
||||
class TestCodeBasedExtensionService:
|
||||
def test_should_return_only_non_builtin_extensions_with_public_fields(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test service returns only non-builtin extensions with name/label/form_schema fields."""
|
||||
moderation_extension = SimpleNamespace(
|
||||
name="custom-moderation",
|
||||
label={"en-US": "Custom Moderation"},
|
||||
form_schema=[{"variable": "api_key"}],
|
||||
builtin=False,
|
||||
extension_class=object,
|
||||
position=20,
|
||||
)
|
||||
builtin_extension = SimpleNamespace(
|
||||
name="builtin-moderation",
|
||||
label={"en-US": "Builtin Moderation"},
|
||||
form_schema=[{"variable": "token"}],
|
||||
builtin=True,
|
||||
extension_class=object,
|
||||
position=1,
|
||||
)
|
||||
retrieval_extension = SimpleNamespace(
|
||||
name="custom-retrieval",
|
||||
label={"en-US": "Custom Retrieval"},
|
||||
form_schema=None,
|
||||
builtin=False,
|
||||
extension_class=object,
|
||||
position=30,
|
||||
)
|
||||
module_extensions_mock = MagicMock(return_value=[moderation_extension, builtin_extension, retrieval_extension])
|
||||
monkeypatch.setattr(
|
||||
"services.code_based_extension_service.code_based_extension.module_extensions",
|
||||
module_extensions_mock,
|
||||
)
|
||||
|
||||
result = CodeBasedExtensionService.get_code_based_extension("external_data_tool")
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"name": "custom-moderation",
|
||||
"label": {"en-US": "Custom Moderation"},
|
||||
"form_schema": [{"variable": "api_key"}],
|
||||
},
|
||||
{
|
||||
"name": "custom-retrieval",
|
||||
"label": {"en-US": "Custom Retrieval"},
|
||||
"form_schema": None,
|
||||
},
|
||||
]
|
||||
assert set(result[0].keys()) == {"name", "label", "form_schema"}
|
||||
module_extensions_mock.assert_called_once_with("external_data_tool")
|
||||
|
||||
def test_should_return_empty_list_when_all_extensions_are_builtin(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test builtin extensions are filtered out completely."""
|
||||
builtin_extension = SimpleNamespace(
|
||||
name="builtin-moderation",
|
||||
label={"en-US": "Builtin Moderation"},
|
||||
form_schema=[{"variable": "token"}],
|
||||
builtin=True,
|
||||
)
|
||||
module_extensions_mock = MagicMock(return_value=[builtin_extension])
|
||||
monkeypatch.setattr(
|
||||
"services.code_based_extension_service.code_based_extension.module_extensions",
|
||||
module_extensions_mock,
|
||||
)
|
||||
|
||||
result = CodeBasedExtensionService.get_code_based_extension("moderation")
|
||||
|
||||
assert result == []
|
||||
module_extensions_mock.assert_called_once_with("moderation")
|
||||
|
||||
def test_should_propagate_error_when_module_extensions_lookup_fails(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test ValueError from extension lookup bubbles up unchanged."""
|
||||
module_extensions_mock = MagicMock(side_effect=ValueError("Extension Module invalid-module not found"))
|
||||
monkeypatch.setattr(
|
||||
"services.code_based_extension_service.code_based_extension.module_extensions",
|
||||
module_extensions_mock,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Extension Module invalid-module not found"):
|
||||
CodeBasedExtensionService.get_code_based_extension("invalid-module")
|
||||
|
||||
module_extensions_mock.assert_called_once_with("invalid-module")
|
||||
@ -0,0 +1,75 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.variables.variables import StringVariable
|
||||
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
|
||||
|
||||
|
||||
class TestConversationVariableUpdater:
|
||||
def test_should_update_conversation_variable_data_and_commit(self):
|
||||
"""Test update persists serialized variable data when the row exists."""
|
||||
conversation_id = "conv-123"
|
||||
variable = StringVariable(
|
||||
id="var-123",
|
||||
name="topic",
|
||||
value="new value",
|
||||
)
|
||||
expected_json = variable.model_dump_json()
|
||||
|
||||
row = SimpleNamespace(data="old value")
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = row
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
session_maker = MagicMock(return_value=session_context)
|
||||
updater = ConversationVariableUpdater(session_maker)
|
||||
|
||||
updater.update(conversation_id=conversation_id, variable=variable)
|
||||
|
||||
session_maker.assert_called_once_with()
|
||||
session.scalar.assert_called_once()
|
||||
stmt = session.scalar.call_args.args[0]
|
||||
compiled_params = stmt.compile().params
|
||||
assert variable.id in compiled_params.values()
|
||||
assert conversation_id in compiled_params.values()
|
||||
assert row.data == expected_json
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_should_raise_not_found_error_when_conversation_variable_missing(self):
|
||||
"""Test update raises ConversationVariableNotFoundError when no matching row exists."""
|
||||
conversation_id = "conv-404"
|
||||
variable = StringVariable(
|
||||
id="var-404",
|
||||
name="topic",
|
||||
value="value",
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
session_maker = MagicMock(return_value=session_context)
|
||||
updater = ConversationVariableUpdater(session_maker)
|
||||
|
||||
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
|
||||
updater.update(conversation_id=conversation_id, variable=variable)
|
||||
|
||||
session.commit.assert_not_called()
|
||||
|
||||
def test_should_do_nothing_when_flush_is_called(self):
|
||||
"""Test flush currently behaves as a no-op and returns None."""
|
||||
session_maker = MagicMock()
|
||||
updater = ConversationVariableUpdater(session_maker)
|
||||
|
||||
result = updater.flush()
|
||||
|
||||
assert result is None
|
||||
session_maker.assert_not_called()
|
||||
157
api/tests/unit_tests/services/test_credit_pool_service.py
Normal file
157
api/tests/unit_tests/services/test_credit_pool_service.py
Normal file
@ -0,0 +1,157 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import services.credit_pool_service as credit_pool_service_module
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_deduction_setup():
|
||||
"""Fixture providing common setup for credit deduction tests."""
|
||||
pool = SimpleNamespace(remaining_credits=50)
|
||||
fake_engine = MagicMock()
|
||||
session = MagicMock()
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
mock_get_pool = patch.object(CreditPoolService, "get_pool", return_value=pool)
|
||||
mock_db = patch.object(credit_pool_service_module, "db", new=SimpleNamespace(engine=fake_engine))
|
||||
mock_session = patch.object(credit_pool_service_module, "Session", return_value=session_context)
|
||||
|
||||
return {
|
||||
"pool": pool,
|
||||
"fake_engine": fake_engine,
|
||||
"session": session,
|
||||
"session_context": session_context,
|
||||
"patches": (mock_get_pool, mock_db, mock_session),
|
||||
}
|
||||
|
||||
|
||||
class TestCreditPoolService:
|
||||
def test_should_create_default_pool_with_trial_type_and_configured_quota(self):
|
||||
"""Test create_default_pool persists a trial pool using configured hosted credits."""
|
||||
tenant_id = "tenant-123"
|
||||
hosted_pool_credits = 5000
|
||||
|
||||
with (
|
||||
patch.object(credit_pool_service_module.dify_config, "HOSTED_POOL_CREDITS", hosted_pool_credits),
|
||||
patch.object(credit_pool_service_module, "db") as mock_db,
|
||||
):
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
|
||||
assert isinstance(pool, TenantCreditPool)
|
||||
assert pool.tenant_id == tenant_id
|
||||
assert pool.pool_type == "trial"
|
||||
assert pool.quota_limit == hosted_pool_credits
|
||||
assert pool.quota_used == 0
|
||||
mock_db.session.add.assert_called_once_with(pool)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_should_return_first_pool_from_query_when_get_pool_called(self):
|
||||
"""Test get_pool queries by tenant and pool_type and returns first result."""
|
||||
tenant_id = "tenant-123"
|
||||
pool_type = "enterprise"
|
||||
expected_pool = MagicMock(spec=TenantCreditPool)
|
||||
|
||||
with patch.object(credit_pool_service_module, "db") as mock_db:
|
||||
query = mock_db.session.query.return_value
|
||||
filtered_query = query.filter_by.return_value
|
||||
filtered_query.first.return_value = expected_pool
|
||||
|
||||
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=pool_type)
|
||||
|
||||
assert result == expected_pool
|
||||
mock_db.session.query.assert_called_once_with(TenantCreditPool)
|
||||
query.filter_by.assert_called_once_with(tenant_id=tenant_id, pool_type=pool_type)
|
||||
filtered_query.first.assert_called_once()
|
||||
|
||||
def test_should_return_false_when_pool_not_found_in_check_credits_available(self):
|
||||
"""Test check_credits_available returns False when tenant has no pool."""
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=None) as mock_get_pool:
|
||||
result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
assert result is False
|
||||
mock_get_pool.assert_called_once_with("tenant-123", "trial")
|
||||
|
||||
def test_should_return_true_when_remaining_credits_cover_required_amount(self):
|
||||
"""Test check_credits_available returns True when remaining credits are sufficient."""
|
||||
pool = SimpleNamespace(remaining_credits=100)
|
||||
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=pool) as mock_get_pool:
|
||||
result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60)
|
||||
|
||||
assert result is True
|
||||
mock_get_pool.assert_called_once_with("tenant-123", "trial")
|
||||
|
||||
def test_should_return_false_when_remaining_credits_are_insufficient(self):
|
||||
"""Test check_credits_available returns False when required credits exceed remaining credits."""
|
||||
pool = SimpleNamespace(remaining_credits=30)
|
||||
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=pool):
|
||||
result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_should_raise_quota_exceeded_when_pool_not_found_in_check_and_deduct(self):
|
||||
"""Test check_and_deduct_credits raises when tenant credit pool does not exist."""
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=None):
|
||||
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
def test_should_raise_quota_exceeded_when_pool_has_no_remaining_credits(self):
|
||||
"""Test check_and_deduct_credits raises when remaining credits are zero or negative."""
|
||||
pool = SimpleNamespace(remaining_credits=0)
|
||||
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=pool):
|
||||
with pytest.raises(QuotaExceededError, match="No credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
def test_should_deduct_minimum_of_required_and_remaining_credits(self, mock_credit_deduction_setup):
|
||||
"""Test check_and_deduct_credits updates quota_used by the actual deducted amount."""
|
||||
tenant_id = "tenant-123"
|
||||
pool_type = "trial"
|
||||
credits_required = 200
|
||||
remaining_credits = 120
|
||||
expected_deducted_credits = 120
|
||||
|
||||
mock_credit_deduction_setup["pool"].remaining_credits = remaining_credits
|
||||
patches = mock_credit_deduction_setup["patches"]
|
||||
session = mock_credit_deduction_setup["session"]
|
||||
|
||||
with patches[0], patches[1], patches[2]:
|
||||
result = CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=credits_required,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
|
||||
assert result == expected_deducted_credits
|
||||
session.execute.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
stmt = session.execute.call_args.args[0]
|
||||
compiled_params = stmt.compile().params
|
||||
assert tenant_id in compiled_params.values()
|
||||
assert pool_type in compiled_params.values()
|
||||
assert expected_deducted_credits in compiled_params.values()
|
||||
|
||||
def test_should_raise_quota_exceeded_when_deduction_update_fails(self, mock_credit_deduction_setup):
|
||||
"""Test check_and_deduct_credits translates DB update failures to QuotaExceededError."""
|
||||
mock_credit_deduction_setup["pool"].remaining_credits = 50
|
||||
mock_credit_deduction_setup["session"].execute.side_effect = Exception("db failure")
|
||||
session = mock_credit_deduction_setup["session"]
|
||||
|
||||
patches = mock_credit_deduction_setup["patches"]
|
||||
mock_logger = patch.object(credit_pool_service_module, "logger")
|
||||
|
||||
with patches[0], patches[1], patches[2], mock_logger as mock_logger_obj:
|
||||
with pytest.raises(QuotaExceededError, match="Failed to deduct credits"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
session.commit.assert_not_called()
|
||||
mock_logger_obj.exception.assert_called_once()
|
||||
@ -7,7 +7,7 @@ import pytest
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
@ -175,6 +175,137 @@ class TestMCPToolTransform:
|
||||
# The actual parameter conversion is handled by convert_mcp_schema_to_parameter
|
||||
# which should be tested separately
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_anyof_object_type(self):
|
||||
"""Nullable object schemas should keep the object parameter type."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retrieval_model": {
|
||||
"anyOf": [{"type": "object"}, {"type": "null"}],
|
||||
"description": "检索模型配置",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "retrieval_model"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["retrieval_model"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_oneof_object_type(self):
|
||||
"""Nullable oneOf object schemas should keep the object parameter type."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retrieval_model": {
|
||||
"oneOf": [{"type": "object"}, {"type": "null"}],
|
||||
"description": "检索模型配置",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "retrieval_model"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["retrieval_model"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_handles_null_type(self):
|
||||
"""Schemas with only a null type should fall back to string."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"null_prop_str": {"type": "null"},
|
||||
"null_prop_list": {"type": ["null"]},
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 2
|
||||
param_map = {parameter.name: parameter for parameter in result}
|
||||
assert "null_prop_str" in param_map
|
||||
assert param_map["null_prop_str"].type == ToolParameter.ToolParameterType.STRING
|
||||
assert "null_prop_list" in param_map
|
||||
assert param_map["null_prop_list"].type == ToolParameter.ToolParameterType.STRING
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_allof_object_type_with_multiple_object_items(self):
|
||||
"""Property-level allOf with multiple object items should still resolve to object."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"allOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
},
|
||||
"required": ["enabled"],
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"priority": {"type": "integer", "minimum": 1, "maximum": 10},
|
||||
},
|
||||
"required": ["priority"],
|
||||
},
|
||||
],
|
||||
"description": "Config must match all schemas (allOf)",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "config"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["config"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_allof_object_type(self):
|
||||
"""Composed property schemas should keep the object parameter type."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retrieval_model": {
|
||||
"allOf": [
|
||||
{"type": "object"},
|
||||
{"properties": {"top_k": {"type": "integer"}}},
|
||||
],
|
||||
"description": "检索模型配置",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "retrieval_model"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["retrieval_model"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_limits_recursive_schema_depth(self):
|
||||
"""Self-referential composed schemas should stop resolving after the configured max depth."""
|
||||
recursive_property: dict[str, object] = {"description": "Recursive schema"}
|
||||
recursive_property["anyOf"] = [recursive_property]
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"recursive_config": recursive_property,
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "recursive_config"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.STRING
|
||||
assert result[0].input_schema is None
|
||||
|
||||
def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full):
|
||||
"""Test mcp_provider_to_user_provider with for_list=True."""
|
||||
# Set tools data with null description
|
||||
|
||||
@ -35,7 +35,7 @@ COPY --from=packages /app/web/ .
|
||||
COPY . .
|
||||
|
||||
ENV NODE_OPTIONS="--max-old-space-size=4096"
|
||||
RUN pnpm build:docker
|
||||
RUN pnpm build
|
||||
|
||||
|
||||
# production stage
|
||||
|
||||
@ -11,7 +11,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import ImageInput from '@/app/components/base/app-icon-picker/ImageInput'
|
||||
import getCroppedImg from '@/app/components/base/app-icon-picker/utils'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks'
|
||||
@ -103,7 +103,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
|
||||
<>
|
||||
<div>
|
||||
<div className="group relative">
|
||||
<Avatar {...props} onError={(x: boolean) => setOnAvatarError(x)} />
|
||||
<Avatar {...props} onLoadingStatusChange={status => setOnAvatarError(status === 'error')} />
|
||||
<div
|
||||
className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100"
|
||||
onClick={() => {
|
||||
|
||||
@ -4,6 +4,7 @@ import type { App } from '@/types/app'
|
||||
import {
|
||||
RiGraduationCapFill,
|
||||
} from '@remixicon/react'
|
||||
import { useQueryClient } from '@tanstack/react-query'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
@ -15,11 +16,11 @@ import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import Collapse from '@/app/components/header/account-setting/collapse'
|
||||
import { IS_CE_EDITION, validPassword } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { updateUserProfile } from '@/service/common'
|
||||
import { useAppList } from '@/service/use-apps'
|
||||
import { commonQueryKeys, useUserProfile } from '@/service/use-common'
|
||||
import DeleteAccount from '../delete-account'
|
||||
|
||||
import AvatarWithEdit from './AvatarWithEdit'
|
||||
@ -37,7 +38,10 @@ export default function AccountPage() {
|
||||
const { systemFeatures } = useGlobalPublicStore()
|
||||
const { data: appList } = useAppList({ page: 1, limit: 100, name: '' })
|
||||
const apps = appList?.data || []
|
||||
const { mutateUserProfile, userProfile } = useAppContext()
|
||||
const queryClient = useQueryClient()
|
||||
const { data: userProfileResp } = useUserProfile()
|
||||
const userProfile = userProfileResp?.profile
|
||||
const mutateUserProfile = () => queryClient.invalidateQueries({ queryKey: commonQueryKeys.userProfile })
|
||||
const { isEducationAccount } = useProviderContext()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const [editNameModalVisible, setEditNameModalVisible] = useState(false)
|
||||
@ -53,6 +57,9 @@ export default function AccountPage() {
|
||||
const [showConfirmPassword, setShowConfirmPassword] = useState(false)
|
||||
const [showUpdateEmail, setShowUpdateEmail] = useState(false)
|
||||
|
||||
if (!userProfile)
|
||||
return null
|
||||
|
||||
const handleEditName = () => {
|
||||
setEditNameModalVisible(true)
|
||||
setEditName(userProfile.name)
|
||||
@ -149,7 +156,7 @@ export default function AccountPage() {
|
||||
<h4 className="text-text-primary title-2xl-semi-bold">{t('account.myAccount', { ns: 'common' })}</h4>
|
||||
</div>
|
||||
<div className="mb-8 flex items-center rounded-xl bg-gradient-to-r from-background-gradient-bg-fill-chat-bg-2 to-background-gradient-bg-fill-chat-bg-1 p-6">
|
||||
<AvatarWithEdit avatar={userProfile.avatar_url} name={userProfile.name} onSave={mutateUserProfile} size={64} />
|
||||
<AvatarWithEdit avatar={userProfile.avatar_url} name={userProfile.name} onSave={mutateUserProfile} size="3xl" />
|
||||
<div className="ml-4">
|
||||
<p className="text-text-primary system-xl-semibold">
|
||||
{userProfile.name}
|
||||
|
||||
@ -7,12 +7,11 @@ import { useRouter } from 'next/navigation'
|
||||
import { Fragment } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useLogout } from '@/service/use-common'
|
||||
import { useLogout, useUserProfile } from '@/service/use-common'
|
||||
|
||||
export type IAppSelector = {
|
||||
isMobile: boolean
|
||||
@ -21,10 +20,15 @@ export type IAppSelector = {
|
||||
export default function AppSelector() {
|
||||
const router = useRouter()
|
||||
const { t } = useTranslation()
|
||||
const { userProfile } = useAppContext()
|
||||
const { data: userProfileResp } = useUserProfile()
|
||||
const userProfile = userProfileResp?.profile
|
||||
const { isEducationAccount } = useProviderContext()
|
||||
|
||||
const { mutateAsync: logout } = useLogout()
|
||||
|
||||
if (!userProfile)
|
||||
return null
|
||||
|
||||
const handleLogout = async () => {
|
||||
await logout()
|
||||
|
||||
@ -50,7 +54,7 @@ export default function AppSelector() {
|
||||
${open && 'bg-components-panel-bg-blur'}
|
||||
`}
|
||||
>
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={32} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} />
|
||||
</MenuButton>
|
||||
</div>
|
||||
<Transition
|
||||
@ -84,7 +88,7 @@ export default function AppSelector() {
|
||||
</div>
|
||||
<div className="break-all text-text-tertiary system-xs-regular">{userProfile.email}</div>
|
||||
</div>
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={32} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} />
|
||||
</div>
|
||||
</div>
|
||||
</MenuItem>
|
||||
|
||||
@ -11,14 +11,13 @@ import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import * as React from 'react'
|
||||
import { useEffect, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useIsLogin } from '@/service/use-common'
|
||||
import { useIsLogin, useUserProfile } from '@/service/use-common'
|
||||
import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth'
|
||||
|
||||
function buildReturnUrl(pathname: string, search: string) {
|
||||
@ -62,7 +61,8 @@ export default function OAuthAuthorize() {
|
||||
const searchParams = useSearchParams()
|
||||
const client_id = decodeURIComponent(searchParams.get('client_id') || '')
|
||||
const redirect_uri = decodeURIComponent(searchParams.get('redirect_uri') || '')
|
||||
const { userProfile } = useAppContext()
|
||||
const { data: userProfileResp } = useUserProfile()
|
||||
const userProfile = userProfileResp?.profile
|
||||
const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri)
|
||||
const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp()
|
||||
const hasNotifiedRef = useRef(false)
|
||||
@ -138,7 +138,7 @@ export default function OAuthAuthorize() {
|
||||
{isLoggedIn && userProfile && (
|
||||
<div className="flex items-center justify-between rounded-xl bg-background-section-burn-inverted p-3">
|
||||
<div className="flex items-center gap-2.5">
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={36} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size="lg" />
|
||||
<div>
|
||||
<div className="system-md-semi-bold text-text-secondary">{userProfile.name}</div>
|
||||
<div className="text-text-tertiary system-xs-regular">{userProfile.email}</div>
|
||||
|
||||
@ -10,7 +10,7 @@ import { SubjectType } from '@/models/access-control'
|
||||
import { useSearchForWhiteListCandidates } from '@/service/access-control'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import Avatar from '../../base/avatar'
|
||||
import { Avatar } from '../../base/avatar'
|
||||
import Button from '../../base/button'
|
||||
import Checkbox from '../../base/checkbox'
|
||||
import Input from '../../base/input'
|
||||
@ -203,7 +203,7 @@ function MemberItem({ member }: MemberItemProps) {
|
||||
<div className="flex grow items-center">
|
||||
<div className="mr-2 h-5 w-5 overflow-hidden rounded-full bg-components-icon-bg-blue-solid">
|
||||
<div className="bg-access-app-icon-mask-bg flex h-full w-full items-center justify-center">
|
||||
<Avatar className="h-[14px] w-[14px]" textClassName="text-[12px]" avatar={null} name={member.name} />
|
||||
<Avatar size="xxs" avatar={null} name={member.name} />
|
||||
</div>
|
||||
</div>
|
||||
<p className="mr-1 text-text-secondary system-sm-medium">{member.name}</p>
|
||||
|
||||
@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { useAppWhiteListSubjects } from '@/service/access-control'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import Avatar from '../../base/avatar'
|
||||
import { Avatar } from '../../base/avatar'
|
||||
import Loading from '../../base/loading'
|
||||
import Tooltip from '../../base/tooltip'
|
||||
import AddMemberOrGroupDialog from './add-member-or-group-pop'
|
||||
@ -106,7 +106,7 @@ function MemberItem({ member }: MemberItemProps) {
|
||||
}, [member, setSpecificMembers, specificMembers])
|
||||
return (
|
||||
<BaseItem
|
||||
icon={<Avatar className="h-[14px] w-[14px]" textClassName="text-[12px]" avatar={null} name={member.name} />}
|
||||
icon={<Avatar size="xxs" avatar={null} name={member.name} />}
|
||||
onRemove={handleRemoveMember}
|
||||
>
|
||||
<p className="text-text-primary system-xs-regular">{member.name}</p>
|
||||
|
||||
@ -91,7 +91,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/avatar', () => ({
|
||||
default: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
|
||||
Avatar: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
|
||||
}))
|
||||
|
||||
const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
|
||||
|
||||
@ -7,7 +7,7 @@ import {
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Chat from '@/app/components/base/chat/chat'
|
||||
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
||||
import { getLastAnswer } from '@/app/components/base/chat/utils'
|
||||
@ -149,7 +149,7 @@ const ChatItem: FC<ChatItemProps> = ({
|
||||
suggestedQuestions={suggestedQuestions}
|
||||
onSend={doSend}
|
||||
showPromptLog
|
||||
questionIcon={<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={40} />}
|
||||
questionIcon={<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size="xl" />}
|
||||
allToolIcons={allToolIcons}
|
||||
hideLogModal
|
||||
noSpacing
|
||||
|
||||
@ -3,7 +3,7 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import { memo, useCallback, useImperativeHandle, useMemo } from 'react'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Chat from '@/app/components/base/chat/chat'
|
||||
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
||||
import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils'
|
||||
@ -168,7 +168,7 @@ const DebugWithSingleModel = (
|
||||
switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)}
|
||||
onStopResponding={handleStop}
|
||||
showPromptLog
|
||||
questionIcon={<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={40} />}
|
||||
questionIcon={<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size="xl" />}
|
||||
allToolIcons={allToolIcons}
|
||||
onAnnotationEdited={handleAnnotationEdited}
|
||||
onAnnotationAdded={handleAnnotationAdded}
|
||||
|
||||
@ -1,308 +1,114 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import Avatar from '../index'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { Avatar } from '../index'
|
||||
|
||||
describe('Avatar', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Rendering tests - verify component renders correctly in different states
|
||||
describe('Rendering', () => {
|
||||
it('should render img element with correct alt and src when avatar URL is provided', () => {
|
||||
const avatarUrl = 'https://example.com/avatar.jpg'
|
||||
const props = { name: 'John Doe', avatar: avatarUrl }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
it('should render img element when avatar URL is provided', () => {
|
||||
render(<Avatar name="John Doe" avatar="https://example.com/avatar.jpg" />)
|
||||
|
||||
const img = screen.getByRole('img', { name: 'John Doe' })
|
||||
expect(img).toBeInTheDocument()
|
||||
expect(img).toHaveAttribute('src', avatarUrl)
|
||||
expect(img).toHaveAttribute('src', 'https://example.com/avatar.jpg')
|
||||
})
|
||||
|
||||
it('should render fallback div with uppercase initial when avatar is null', () => {
|
||||
const props = { name: 'alice', avatar: null }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
it('should render fallback with uppercase initial when avatar is null', () => {
|
||||
render(<Avatar name="alice" avatar={null} />)
|
||||
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('A')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Props tests - verify all props are applied correctly
|
||||
describe('Props', () => {
|
||||
describe('size prop', () => {
|
||||
it.each([
|
||||
{ size: undefined, expected: '30px', label: 'default (30px)' },
|
||||
{ size: 50, expected: '50px', label: 'custom (50px)' },
|
||||
])('should apply $label size to img element', ({ size, expected }) => {
|
||||
const props = { name: 'Test', avatar: 'https://example.com/avatar.jpg', size }
|
||||
it('should render both image and fallback when avatar is provided', () => {
|
||||
render(<Avatar name="John" avatar="https://example.com/avatar.jpg" />)
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
expect(screen.getByRole('img')).toHaveStyle({
|
||||
width: expected,
|
||||
height: expected,
|
||||
fontSize: expected,
|
||||
lineHeight: expected,
|
||||
})
|
||||
})
|
||||
|
||||
it('should apply size to fallback div when avatar is null', () => {
|
||||
const props = { name: 'Test', avatar: null, size: 40 }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveStyle({ width: '40px', height: '40px' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('className prop', () => {
|
||||
it('should merge className with default avatar classes on img', () => {
|
||||
const props = {
|
||||
name: 'Test',
|
||||
avatar: 'https://example.com/avatar.jpg',
|
||||
className: 'custom-class',
|
||||
}
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const img = screen.getByRole('img')
|
||||
expect(img).toHaveClass('custom-class')
|
||||
expect(img).toHaveClass('shrink-0', 'flex', 'items-center', 'rounded-full', 'bg-primary-600')
|
||||
})
|
||||
|
||||
it('should merge className with default avatar classes on fallback div', () => {
|
||||
const props = {
|
||||
name: 'Test',
|
||||
avatar: null,
|
||||
className: 'my-custom-class',
|
||||
}
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveClass('my-custom-class')
|
||||
expect(outerDiv).toHaveClass('shrink-0', 'flex', 'items-center', 'rounded-full', 'bg-primary-600')
|
||||
})
|
||||
})
|
||||
|
||||
describe('textClassName prop', () => {
|
||||
it('should apply textClassName to the initial text element', () => {
|
||||
const props = {
|
||||
name: 'Test',
|
||||
avatar: null,
|
||||
textClassName: 'custom-text-class',
|
||||
}
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
expect(textElement).toHaveClass('custom-text-class')
|
||||
expect(textElement).toHaveClass('scale-[0.4]', 'text-center', 'text-white')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// State Management tests - verify useState and useEffect behavior
|
||||
describe('State Management', () => {
|
||||
it('should switch to fallback when image fails to load', async () => {
|
||||
const props = { name: 'John', avatar: 'https://example.com/broken.jpg' }
|
||||
render(<Avatar {...props} />)
|
||||
const img = screen.getByRole('img')
|
||||
|
||||
fireEvent.error(img)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should reset error state when avatar URL changes', async () => {
|
||||
const initialProps = { name: 'John', avatar: 'https://example.com/broken.jpg' }
|
||||
const { rerender } = render(<Avatar {...initialProps} />)
|
||||
const img = screen.getByRole('img')
|
||||
|
||||
// First, trigger error
|
||||
fireEvent.error(img)
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
|
||||
rerender(<Avatar name="John" avatar="https://example.com/new-avatar.jpg" />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('img')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.queryByText('J')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not reset error state if avatar becomes null', async () => {
|
||||
const initialProps = { name: 'John', avatar: 'https://example.com/broken.jpg' }
|
||||
const { rerender } = render(<Avatar {...initialProps} />)
|
||||
|
||||
// Trigger error
|
||||
fireEvent.error(screen.getByRole('img'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
rerender(<Avatar name="John" avatar={null} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByRole('img')).toBeInTheDocument()
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Event Handlers tests - verify onError callback behavior
|
||||
describe('Event Handlers', () => {
|
||||
it('should call onError with true when image fails to load', () => {
|
||||
const onErrorMock = vi.fn()
|
||||
const props = {
|
||||
name: 'John',
|
||||
avatar: 'https://example.com/broken.jpg',
|
||||
onError: onErrorMock,
|
||||
}
|
||||
render(<Avatar {...props} />)
|
||||
describe('Size variants', () => {
|
||||
it.each([
|
||||
{ size: 'xxs' as const, expectedClass: 'size-4' },
|
||||
{ size: 'xs' as const, expectedClass: 'size-5' },
|
||||
{ size: 'sm' as const, expectedClass: 'size-6' },
|
||||
{ size: 'md' as const, expectedClass: 'size-8' },
|
||||
{ size: 'lg' as const, expectedClass: 'size-9' },
|
||||
{ size: 'xl' as const, expectedClass: 'size-10' },
|
||||
{ size: '2xl' as const, expectedClass: 'size-12' },
|
||||
{ size: '3xl' as const, expectedClass: 'size-16' },
|
||||
])('should apply $expectedClass for size="$size"', ({ size, expectedClass }) => {
|
||||
const { container } = render(<Avatar name="Test" avatar={null} size={size} />)
|
||||
|
||||
fireEvent.error(screen.getByRole('img'))
|
||||
|
||||
expect(onErrorMock).toHaveBeenCalledTimes(1)
|
||||
expect(onErrorMock).toHaveBeenCalledWith(true)
|
||||
const root = container.firstElementChild as HTMLElement
|
||||
expect(root).toHaveClass(expectedClass)
|
||||
})
|
||||
|
||||
it('should call onError with false when image loads successfully', () => {
|
||||
const onErrorMock = vi.fn()
|
||||
const props = {
|
||||
name: 'John',
|
||||
avatar: 'https://example.com/avatar.jpg',
|
||||
onError: onErrorMock,
|
||||
}
|
||||
render(<Avatar {...props} />)
|
||||
it('should default to md size when size is not specified', () => {
|
||||
const { container } = render(<Avatar name="Test" avatar={null} />)
|
||||
|
||||
fireEvent.load(screen.getByRole('img'))
|
||||
|
||||
expect(onErrorMock).toHaveBeenCalledTimes(1)
|
||||
expect(onErrorMock).toHaveBeenCalledWith(false)
|
||||
})
|
||||
|
||||
it('should not throw when onError is not provided', async () => {
|
||||
const props = { name: 'John', avatar: 'https://example.com/broken.jpg' }
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
expect(() => fireEvent.error(screen.getByRole('img'))).not.toThrow()
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('J')).toBeInTheDocument()
|
||||
})
|
||||
const root = container.firstElementChild as HTMLElement
|
||||
expect(root).toHaveClass('size-8')
|
||||
})
|
||||
})
|
||||
|
||||
describe('className prop', () => {
|
||||
it('should merge className with avatar variant classes on root', () => {
|
||||
const { container } = render(
|
||||
<Avatar name="Test" avatar={null} className="custom-class" />,
|
||||
)
|
||||
|
||||
const root = container.firstElementChild as HTMLElement
|
||||
expect(root).toHaveClass('custom-class')
|
||||
expect(root).toHaveClass('rounded-full', 'bg-primary-600')
|
||||
})
|
||||
})
|
||||
|
||||
// Edge Cases tests - verify handling of unusual inputs
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty string name gracefully', () => {
|
||||
const props = { name: '', avatar: null }
|
||||
const { container } = render(<Avatar name="" avatar={null} />)
|
||||
|
||||
const { container } = render(<Avatar {...props} />)
|
||||
|
||||
// Note: Using querySelector here because empty name produces no visible text,
|
||||
// making semantic queries (getByRole, getByText) impossible
|
||||
const textElement = container.querySelector('.text-white') as HTMLElement
|
||||
expect(textElement).toBeInTheDocument()
|
||||
expect(textElement.textContent).toBe('')
|
||||
const fallback = container.querySelector('.text-white') as HTMLElement
|
||||
expect(fallback).toBeInTheDocument()
|
||||
expect(fallback.textContent).toBe('')
|
||||
})
|
||||
|
||||
it.each([
|
||||
{ name: '中文名', expected: '中', label: 'Chinese characters' },
|
||||
{ name: '123User', expected: '1', label: 'number' },
|
||||
])('should display first character when name starts with $label', ({ name, expected }) => {
|
||||
const props = { name, avatar: null }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
render(<Avatar name={name} avatar={null} />)
|
||||
|
||||
expect(screen.getByText(expected)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty string avatar as falsy value', () => {
|
||||
const props = { name: 'Test', avatar: '' as string | null }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
render(<Avatar name="Test" avatar={'' as string | null} />)
|
||||
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('T')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle undefined className and textClassName', () => {
|
||||
const props = { name: 'Test', avatar: null }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveClass('shrink-0', 'flex', 'items-center', 'rounded-full', 'bg-primary-600')
|
||||
})
|
||||
|
||||
it.each([
|
||||
{ size: 0, expected: '0px', label: 'zero' },
|
||||
{ size: 1000, expected: '1000px', label: 'very large' },
|
||||
])('should handle $label size value', ({ size, expected }) => {
|
||||
const props = { name: 'Test', avatar: null, size }
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('T')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveStyle({ width: expected, height: expected })
|
||||
})
|
||||
})
|
||||
|
||||
// Combined props tests - verify props work together correctly
|
||||
describe('Combined Props', () => {
|
||||
it('should apply all props correctly when used together', () => {
|
||||
const onErrorMock = vi.fn()
|
||||
const props = {
|
||||
name: 'Test User',
|
||||
avatar: 'https://example.com/avatar.jpg',
|
||||
size: 64,
|
||||
className: 'custom-avatar',
|
||||
onError: onErrorMock,
|
||||
}
|
||||
describe('onLoadingStatusChange', () => {
|
||||
it('should render image when avatar and onLoadingStatusChange are provided', () => {
|
||||
render(
|
||||
<Avatar
|
||||
name="John"
|
||||
avatar="https://example.com/avatar.jpg"
|
||||
onLoadingStatusChange={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const img = screen.getByRole('img')
|
||||
expect(img).toHaveAttribute('alt', 'Test User')
|
||||
expect(img).toHaveAttribute('src', 'https://example.com/avatar.jpg')
|
||||
expect(img).toHaveStyle({ width: '64px', height: '64px' })
|
||||
expect(img).toHaveClass('custom-avatar')
|
||||
|
||||
// Trigger load to verify onError callback
|
||||
fireEvent.load(img)
|
||||
expect(onErrorMock).toHaveBeenCalledWith(false)
|
||||
expect(screen.getByRole('img')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should apply all fallback props correctly when used together', () => {
|
||||
const props = {
|
||||
name: 'Fallback User',
|
||||
avatar: null,
|
||||
size: 48,
|
||||
className: 'fallback-custom',
|
||||
textClassName: 'custom-text-style',
|
||||
}
|
||||
it('should not render image when avatar is null even with onLoadingStatusChange', () => {
|
||||
const onStatusChange = vi.fn()
|
||||
render(
|
||||
<Avatar name="John" avatar={null} onLoadingStatusChange={onStatusChange} />,
|
||||
)
|
||||
|
||||
render(<Avatar {...props} />)
|
||||
|
||||
const textElement = screen.getByText('F')
|
||||
const outerDiv = textElement.parentElement as HTMLElement
|
||||
expect(outerDiv).toHaveClass('fallback-custom')
|
||||
expect(outerDiv).toHaveStyle({ width: '48px', height: '48px' })
|
||||
expect(textElement).toHaveClass('custom-text-style')
|
||||
expect(screen.queryByRole('img')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import Avatar from '.'
|
||||
import { Avatar } from '.'
|
||||
|
||||
const meta = {
|
||||
title: 'Base/Data Display/Avatar',
|
||||
@ -7,12 +7,12 @@ const meta = {
|
||||
parameters: {
|
||||
docs: {
|
||||
description: {
|
||||
component: 'Initials or image-based avatar used across contacts and member lists. Falls back to the first letter when the image fails to load.',
|
||||
component: 'Initials or image-based avatar built on Base UI. Falls back to the first letter when the image fails to load.',
|
||||
},
|
||||
source: {
|
||||
language: 'tsx',
|
||||
code: `
|
||||
<Avatar name="Alex Doe" avatar="https://cloud.dify.ai/logo/logo.svg" size={40} />
|
||||
<Avatar name="Alex Doe" avatar="https://i.pravatar.cc/96?u=avatar-default" size="xl" />
|
||||
`.trim(),
|
||||
},
|
||||
},
|
||||
@ -20,8 +20,8 @@ const meta = {
|
||||
tags: ['autodocs'],
|
||||
args: {
|
||||
name: 'Alex Doe',
|
||||
avatar: 'https://cloud.dify.ai/logo/logo.svg',
|
||||
size: 40,
|
||||
avatar: 'https://i.pravatar.cc/96?u=avatar-default',
|
||||
size: 'xl',
|
||||
},
|
||||
} satisfies Meta<typeof Avatar>
|
||||
|
||||
@ -40,23 +40,20 @@ export const WithFallback: Story = {
|
||||
source: {
|
||||
language: 'tsx',
|
||||
code: `
|
||||
<Avatar name="Fallback" avatar={null} size={40} />
|
||||
<Avatar name="Fallback" avatar={null} size="xl" />
|
||||
`.trim(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
export const CustomSizes: Story = {
|
||||
export const AllSizes: Story = {
|
||||
render: args => (
|
||||
<div className="flex items-end gap-4">
|
||||
{[24, 32, 48, 64].map(size => (
|
||||
{(['xxs', 'xs', 'sm', 'md', 'lg', 'xl', '2xl', '3xl'] as const).map(size => (
|
||||
<div key={size} className="flex flex-col items-center gap-2">
|
||||
<Avatar {...args} size={size} avatar="https://i.pravatar.cc/96?u=size-test" />
|
||||
<span className="text-xs text-text-tertiary">
|
||||
{size}
|
||||
px
|
||||
</span>
|
||||
<span className="text-xs text-text-tertiary">{size}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
@ -66,7 +63,7 @@ export const CustomSizes: Story = {
|
||||
source: {
|
||||
language: 'tsx',
|
||||
code: `
|
||||
{[24, 32, 48, 64].map(size => (
|
||||
{(['xxs', 'xs', 'sm', 'md', 'lg', 'xl', '2xl', '3xl'] as const).map(size => (
|
||||
<Avatar key={size} name="Size Test" size={size} avatar="https://i.pravatar.cc/96?u=size-test" />
|
||||
))}
|
||||
`.trim(),
|
||||
@ -74,3 +71,16 @@ export const CustomSizes: Story = {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
export const AllFallbackSizes: Story = {
|
||||
render: args => (
|
||||
<div className="flex items-end gap-4">
|
||||
{(['xxs', 'xs', 'sm', 'md', 'lg', 'xl', '2xl', '3xl'] as const).map(size => (
|
||||
<div key={size} className="flex flex-col items-center gap-2">
|
||||
<Avatar {...args} size={size} avatar={null} name="Alex" />
|
||||
<span className="text-xs text-text-tertiary">{size}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
}
|
||||
|
||||
@ -1,64 +1,52 @@
|
||||
'use client'
|
||||
import { useEffect, useState } from 'react'
|
||||
import type { ImageLoadingStatus } from '@base-ui/react/avatar'
|
||||
import { Avatar as BaseAvatar } from '@base-ui/react/avatar'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
const SIZES = {
|
||||
'xxs': { root: 'size-4', text: 'text-[7px]' },
|
||||
'xs': { root: 'size-5', text: 'text-[8px]' },
|
||||
'sm': { root: 'size-6', text: 'text-[10px]' },
|
||||
'md': { root: 'size-8', text: 'text-xs' },
|
||||
'lg': { root: 'size-9', text: 'text-sm' },
|
||||
'xl': { root: 'size-10', text: 'text-base' },
|
||||
'2xl': { root: 'size-12', text: 'text-xl' },
|
||||
'3xl': { root: 'size-16', text: 'text-2xl' },
|
||||
} as const
|
||||
|
||||
export type AvatarSize = keyof typeof SIZES
|
||||
|
||||
export type AvatarProps = {
|
||||
name: string
|
||||
avatar: string | null
|
||||
size?: number
|
||||
size?: AvatarSize
|
||||
className?: string
|
||||
textClassName?: string
|
||||
onError?: (x: boolean) => void
|
||||
onLoadingStatusChange?: (status: ImageLoadingStatus) => void
|
||||
}
|
||||
const Avatar = ({
|
||||
|
||||
const BASE_CLASS = 'relative inline-flex shrink-0 select-none items-center justify-center overflow-hidden rounded-full bg-primary-600'
|
||||
|
||||
export const Avatar = ({
|
||||
name,
|
||||
avatar,
|
||||
size = 30,
|
||||
size = 'md',
|
||||
className,
|
||||
textClassName,
|
||||
onError,
|
||||
onLoadingStatusChange,
|
||||
}: AvatarProps) => {
|
||||
const avatarClassName = 'shrink-0 flex items-center rounded-full bg-primary-600'
|
||||
const style = { width: `${size}px`, height: `${size}px`, fontSize: `${size}px`, lineHeight: `${size}px` }
|
||||
const [imgError, setImgError] = useState(false)
|
||||
|
||||
const handleError = () => {
|
||||
setImgError(true)
|
||||
onError?.(true)
|
||||
}
|
||||
|
||||
// after uploaded, api would first return error imgs url: '.../files//file-preview/...'. Then return the right url, Which caused not show the avatar
|
||||
useEffect(() => {
|
||||
if (avatar && imgError)
|
||||
setImgError(false)
|
||||
}, [avatar])
|
||||
|
||||
if (avatar && !imgError) {
|
||||
return (
|
||||
<img
|
||||
className={cn(avatarClassName, className)}
|
||||
style={style}
|
||||
alt={name}
|
||||
src={avatar}
|
||||
onError={handleError}
|
||||
onLoad={() => onError?.(false)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
const sizeConfig = SIZES[size]
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(avatarClassName, className)}
|
||||
style={style}
|
||||
>
|
||||
<div
|
||||
className={cn(textClassName, 'scale-[0.4] text-center text-white')}
|
||||
style={style}
|
||||
>
|
||||
{name && name[0].toLocaleUpperCase()}
|
||||
</div>
|
||||
</div>
|
||||
<BaseAvatar.Root className={cn(BASE_CLASS, sizeConfig.root, className)}>
|
||||
{avatar && (
|
||||
<BaseAvatar.Image
|
||||
src={avatar}
|
||||
alt={name}
|
||||
className="absolute inset-0 size-full object-cover"
|
||||
onLoadingStatusChange={onLoadingStatusChange}
|
||||
/>
|
||||
)}
|
||||
<BaseAvatar.Fallback className={cn('font-medium text-white', sizeConfig.text)}>
|
||||
{name?.[0]?.toLocaleUpperCase()}
|
||||
</BaseAvatar.Fallback>
|
||||
</BaseAvatar.Root>
|
||||
)
|
||||
}
|
||||
|
||||
export default Avatar
|
||||
|
||||
@ -23,7 +23,7 @@ import { submitHumanInputForm as submitHumanInputFormService } from '@/service/w
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { formatBooleanInputs } from '@/utils/model-config'
|
||||
import Avatar from '../../avatar'
|
||||
import { Avatar } from '../../avatar'
|
||||
import Chat from '../chat'
|
||||
import { useChat } from '../chat/hooks'
|
||||
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
||||
@ -351,7 +351,7 @@ const ChatWrapper = () => {
|
||||
<Avatar
|
||||
avatar={initUserVariables.avatar_url}
|
||||
name={initUserVariables.name || 'user'}
|
||||
size={40}
|
||||
size="xl"
|
||||
/>
|
||||
)
|
||||
: undefined
|
||||
|
||||
@ -23,7 +23,7 @@ import {
|
||||
import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Avatar from '../../avatar'
|
||||
import { Avatar } from '../../avatar'
|
||||
import Chat from '../chat'
|
||||
import { useChat } from '../chat/hooks'
|
||||
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
||||
@ -337,7 +337,7 @@ const ChatWrapper = () => {
|
||||
<Avatar
|
||||
avatar={initUserVariables.avatar_url}
|
||||
name={initUserVariables.name || 'user'}
|
||||
size={40}
|
||||
size="xl"
|
||||
/>
|
||||
)
|
||||
: undefined
|
||||
|
||||
@ -20,17 +20,21 @@ const OnBlurBlock: FC<OnBlurBlockProps> = ({
|
||||
}) => {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
|
||||
const ref = useRef<any>(null)
|
||||
const ref = useRef<ReturnType<typeof setTimeout> | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
return mergeRegister(
|
||||
const clearHideMenuTimeout = () => {
|
||||
if (ref.current) {
|
||||
clearTimeout(ref.current)
|
||||
ref.current = null
|
||||
}
|
||||
}
|
||||
|
||||
const unregister = mergeRegister(
|
||||
editor.registerCommand(
|
||||
CLEAR_HIDE_MENU_TIMEOUT,
|
||||
() => {
|
||||
if (ref.current) {
|
||||
clearTimeout(ref.current)
|
||||
ref.current = null
|
||||
}
|
||||
clearHideMenuTimeout()
|
||||
return true
|
||||
},
|
||||
COMMAND_PRIORITY_EDITOR,
|
||||
@ -41,6 +45,7 @@ const OnBlurBlock: FC<OnBlurBlockProps> = ({
|
||||
// Check if the clicked target element is var-search-input
|
||||
const target = event?.relatedTarget as HTMLElement
|
||||
if (!target?.classList?.contains('var-search-input')) {
|
||||
clearHideMenuTimeout()
|
||||
ref.current = setTimeout(() => {
|
||||
editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' }))
|
||||
}, 200)
|
||||
@ -61,6 +66,11 @@ const OnBlurBlock: FC<OnBlurBlockProps> = ({
|
||||
COMMAND_PRIORITY_EDITOR,
|
||||
),
|
||||
)
|
||||
|
||||
return () => {
|
||||
clearHideMenuTimeout()
|
||||
unregister()
|
||||
}
|
||||
}, [editor, onBlur, onFocus])
|
||||
|
||||
return null
|
||||
|
||||
@ -4,7 +4,7 @@ import { useDebounceFn } from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Input from '@/app/components/base/input'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
@ -106,7 +106,7 @@ const PermissionSelector = ({
|
||||
isOnlyMe && (
|
||||
<>
|
||||
<div className="flex size-6 shrink-0 items-center justify-center">
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={20} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size="xs" />
|
||||
</div>
|
||||
<div className="grow p-1 text-components-input-text-filled system-sm-regular">
|
||||
{t('form.permissionsOnlyMe', { ns: 'datasetSettings' })}
|
||||
@ -135,7 +135,7 @@ const PermissionSelector = ({
|
||||
<Avatar
|
||||
avatar={selectedMembers[0].avatar_url}
|
||||
name={selectedMembers[0].name}
|
||||
size={20}
|
||||
size="xs"
|
||||
/>
|
||||
)
|
||||
}
|
||||
@ -146,13 +146,13 @@ const PermissionSelector = ({
|
||||
avatar={selectedMembers[0].avatar_url}
|
||||
name={selectedMembers[0].name}
|
||||
className="absolute left-0 top-0 z-0"
|
||||
size={16}
|
||||
size="xxs"
|
||||
/>
|
||||
<Avatar
|
||||
avatar={selectedMembers[1].avatar_url}
|
||||
name={selectedMembers[1].name}
|
||||
className="absolute bottom-0 right-0 z-10"
|
||||
size={16}
|
||||
size="xxs"
|
||||
/>
|
||||
</>
|
||||
)
|
||||
@ -182,7 +182,7 @@ const PermissionSelector = ({
|
||||
{/* Only me */}
|
||||
<Item
|
||||
leftIcon={
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} className="shrink-0" size={24} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} className="shrink-0" size="sm" />
|
||||
}
|
||||
text={t('form.permissionsOnlyMe', { ns: 'datasetSettings' })}
|
||||
onClick={onSelectOnlyMe}
|
||||
@ -226,7 +226,7 @@ const PermissionSelector = ({
|
||||
{showMe && (
|
||||
<MemberItem
|
||||
leftIcon={
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} className="shrink-0" size={24} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} className="shrink-0" size="sm" />
|
||||
}
|
||||
name={userProfile.name}
|
||||
email={userProfile.email}
|
||||
@ -237,7 +237,7 @@ const PermissionSelector = ({
|
||||
{filteredMemberList.map(member => (
|
||||
<MemberItem
|
||||
leftIcon={
|
||||
<Avatar avatar={member.avatar_url} name={member.name} className="shrink-0" size={24} />
|
||||
<Avatar avatar={member.avatar_url} name={member.name} className="shrink-0" size="sm" />
|
||||
}
|
||||
name={member.name}
|
||||
email={member.email}
|
||||
|
||||
@ -6,7 +6,7 @@ import { useRouter } from 'next/navigation'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import ThemeSwitcher from '@/app/components/base/theme-switcher'
|
||||
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
|
||||
@ -140,7 +140,7 @@ export default function AppSelector() {
|
||||
aria-label={t('account.account', { ns: 'common' })}
|
||||
className={cn('inline-flex items-center rounded-[20px] p-0.5 hover:bg-background-default-dodge', isAccountMenuOpen && 'bg-background-default-dodge')}
|
||||
>
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={36} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size="lg" />
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent
|
||||
sideOffset={6}
|
||||
@ -160,7 +160,7 @@ export default function AppSelector() {
|
||||
</div>
|
||||
<div className="break-all text-text-tertiary system-xs-regular">{userProfile.email}</div>
|
||||
</div>
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={36} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size="lg" />
|
||||
</div>
|
||||
<AccountMenuRouteItem
|
||||
href="/account"
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
import type { InvitationResult } from '@/models/common'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { NUM_INFINITE } from '@/app/components/billing/config'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
@ -120,7 +120,7 @@ const MembersPage = () => {
|
||||
accounts.map(account => (
|
||||
<div key={account.id} className="flex border-b border-divider-subtle">
|
||||
<div className="flex grow items-center px-3 py-2">
|
||||
<Avatar avatar={account.avatar_url} size={24} className="mr-2" name={account.name} />
|
||||
<Avatar avatar={account.avatar_url} size="sm" className="mr-2" name={account.name} />
|
||||
<div className="">
|
||||
<div className="text-text-secondary system-sm-medium">
|
||||
{account.name}
|
||||
|
||||
@ -3,7 +3,7 @@ import type { FC } from 'react'
|
||||
import * as React from 'react'
|
||||
import { useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
|
||||
import { useMembers } from '@/service/use-common'
|
||||
@ -69,7 +69,7 @@ const MemberSelector: FC<Props> = ({
|
||||
)}
|
||||
{currentValue && (
|
||||
<>
|
||||
<Avatar avatar={currentValue.avatar_url} size={24} name={currentValue.name} />
|
||||
<Avatar avatar={currentValue.avatar_url} size="sm" name={currentValue.name} />
|
||||
<div className="grow truncate text-text-secondary system-sm-medium">{currentValue.name}</div>
|
||||
<div className="text-text-quaternary system-xs-regular">{currentValue.email}</div>
|
||||
</>
|
||||
@ -98,7 +98,7 @@ const MemberSelector: FC<Props> = ({
|
||||
setOpen(false)
|
||||
}}
|
||||
>
|
||||
<Avatar avatar={account.avatar_url} size={24} name={account.name} />
|
||||
<Avatar avatar={account.avatar_url} size="sm" name={account.name} />
|
||||
<div className="grow truncate text-text-secondary system-sm-medium">{account.name}</div>
|
||||
<div className="text-text-quaternary system-xs-regular">{account.email}</div>
|
||||
</div>
|
||||
|
||||
@ -59,9 +59,9 @@ const KeyValueList: FC<Props> = ({
|
||||
return (
|
||||
<div className="overflow-hidden rounded-lg border border-divider-regular">
|
||||
<div className={cn('flex h-7 items-center leading-7 text-text-tertiary system-xs-medium-uppercase')}>
|
||||
<div className={cn('h-full border-r border-divider-regular pl-3', isSupportFile ? 'w-[140px]' : 'w-1/2')}>{t(`${i18nPrefix}.key`, { ns: 'workflow' })}</div>
|
||||
{isSupportFile && <div className="h-full w-[70px] shrink-0 border-r border-divider-regular pl-3">{t(`${i18nPrefix}.type`, { ns: 'workflow' })}</div>}
|
||||
<div className={cn('h-full items-center justify-between pl-3 pr-1', isSupportFile ? 'grow' : 'w-1/2')}>{t(`${i18nPrefix}.value`, { ns: 'workflow' })}</div>
|
||||
<div className={cn('flex h-full items-center border-r border-divider-regular pl-3', isSupportFile ? 'w-[140px]' : 'w-1/2')}>{t(`${i18nPrefix}.key`, { ns: 'workflow' })}</div>
|
||||
{isSupportFile && <div className="flex h-full w-[70px] shrink-0 items-center border-r border-divider-regular pl-3">{t(`${i18nPrefix}.type`, { ns: 'workflow' })}</div>}
|
||||
<div className={cn('flex h-full items-center justify-between pl-3 pr-1', isSupportFile ? 'grow' : 'w-1/2')}>{t(`${i18nPrefix}.value`, { ns: 'workflow' })}</div>
|
||||
</div>
|
||||
{
|
||||
list.map((item, index) => (
|
||||
|
||||
@ -3,7 +3,7 @@ import type { Member } from '@/models/common'
|
||||
import { RiCloseCircleFill, RiErrorWarningFill } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type Props = {
|
||||
@ -34,8 +34,8 @@ const EmailItem = ({
|
||||
{isError && (
|
||||
<RiErrorWarningFill className="h-4 w-4 text-text-destructive" />
|
||||
)}
|
||||
{!isError && <Avatar avatar={data.avatar_url} size={16} name={data.name || data.email} />}
|
||||
<div title={data.email} className="max-w-[500px] truncate text-text-primary system-xs-regular">
|
||||
{!isError && <Avatar avatar={data.avatar_url} size="xxs" name={data.name || data.email} />}
|
||||
<div title={data.email} className="system-xs-regular max-w-[500px] truncate text-text-primary">
|
||||
{email === data.email ? data.name : data.email}
|
||||
{email === data.email && <span className="text-text-tertiary system-xs-regular">{t('members.you', { ns: 'common' })}</span>}
|
||||
</div>
|
||||
|
||||
@ -4,7 +4,7 @@ import type { Recipient } from '@/app/components/workflow/nodes/human-input/type
|
||||
import type { Member } from '@/models/common'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
@ -65,7 +65,7 @@ const MemberList: FC<Props> = ({ searchValue, list, value, onSearchChange, onSel
|
||||
onSelect(account.id)
|
||||
}}
|
||||
>
|
||||
<Avatar className={cn(value.some(item => item.user_id === account.id) && 'opacity-50')} avatar={account.avatar_url} size={24} name={account.name} />
|
||||
<Avatar className={cn(value.some(item => item.user_id === account.id) && 'opacity-50')} avatar={account.avatar_url} size="sm" name={account.name} />
|
||||
<div className={cn('grow', value.some(item => item.user_id === account.id) && 'opacity-50')}>
|
||||
<div className="text-text-secondary system-sm-medium">
|
||||
{account.name}
|
||||
|
||||
@ -96,7 +96,10 @@ const GenericTable: FC<GenericTableProps> = ({
|
||||
})
|
||||
|
||||
// If the last configured row has content, append a trailing empty row
|
||||
const lastHasContent = !isEmptyRow(data.at(-1))
|
||||
const lastRow = data.at(-1)
|
||||
if (!lastRow)
|
||||
return rows
|
||||
const lastHasContent = !isEmptyRow(lastRow)
|
||||
if (lastHasContent)
|
||||
rows.push({ row: { ...emptyRowData }, dataIndex: null, isVirtual: true })
|
||||
|
||||
@ -217,7 +220,7 @@ const GenericTable: FC<GenericTableProps> = ({
|
||||
<div
|
||||
key={column.key}
|
||||
className={cn(
|
||||
'h-full pl-3',
|
||||
'flex h-full items-center pl-3',
|
||||
column.width && column.width.startsWith('w-') ? 'shrink-0' : 'flex-1',
|
||||
column.width,
|
||||
// Add right border except for last column
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { Triangle } from '@/app/components/base/icons/src/public/education'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
@ -34,7 +34,7 @@ const UserInfo = () => {
|
||||
className="mr-4"
|
||||
avatar={userProfile.avatar_url}
|
||||
name={userProfile.name}
|
||||
size={48}
|
||||
size="2xl"
|
||||
/>
|
||||
<div className="pt-1.5">
|
||||
<div className="text-text-primary system-md-semibold">
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
// @ts-check
|
||||
import antfu, { GLOB_TESTS, GLOB_TS, GLOB_TSX } from '@antfu/eslint-config'
|
||||
import antfu, { GLOB_TESTS, GLOB_TS, GLOB_TSX, isInEditorEnv, isInGitHooksOrLintStaged } from '@antfu/eslint-config'
|
||||
import pluginQuery from '@tanstack/eslint-plugin-query'
|
||||
import tailwindcss from 'eslint-plugin-better-tailwindcss'
|
||||
import hyoban from 'eslint-plugin-hyoban'
|
||||
@ -12,6 +12,8 @@ import dify from './plugins/eslint/index.js'
|
||||
// See: tailwind-css-plugin.ts
|
||||
process.env.TAILWIND_MODE ??= 'ESLINT'
|
||||
|
||||
const disableRuleAutoFix = !(isInEditorEnv() || isInGitHooksOrLintStaged())
|
||||
|
||||
export default antfu(
|
||||
{
|
||||
react: {
|
||||
@ -46,6 +48,7 @@ export default antfu(
|
||||
'antfu/top-level-function': 'off',
|
||||
},
|
||||
},
|
||||
e18e: false,
|
||||
},
|
||||
{
|
||||
rules: {
|
||||
@ -148,7 +151,7 @@ export default antfu(
|
||||
},
|
||||
{
|
||||
name: 'dify/base-ui-primitives',
|
||||
files: ['app/components/base/ui/**/*.tsx'],
|
||||
files: ['app/components/base/ui/**/*.tsx', 'app/components/base/avatar/**/*.tsx'],
|
||||
rules: {
|
||||
'react-refresh/only-export-components': 'off',
|
||||
},
|
||||
@ -218,3 +221,10 @@ export default antfu(
|
||||
},
|
||||
},
|
||||
)
|
||||
.disableRulesFix(disableRuleAutoFix
|
||||
? [
|
||||
'tailwindcss/enforce-consistent-class-order',
|
||||
'tailwindcss/no-duplicate-classes',
|
||||
'tailwindcss/no-unnecessary-whitespace',
|
||||
]
|
||||
: [])
|
||||
|
||||
@ -26,39 +26,36 @@
|
||||
"node": "^22.22.1"
|
||||
},
|
||||
"scripts": {
|
||||
"analyze": "next experimental-analyze",
|
||||
"analyze-component": "node ./scripts/analyze-component.js",
|
||||
"build": "next build",
|
||||
"build:vinext": "vinext build",
|
||||
"dev": "next dev",
|
||||
"dev:inspect": "next dev --inspect",
|
||||
"dev:vinext": "vinext dev",
|
||||
"build": "next build",
|
||||
"build:docker": "next build && node scripts/optimize-standalone.js",
|
||||
"build:vinext": "vinext build",
|
||||
"start": "node ./scripts/copy-and-start.mjs",
|
||||
"start:vinext": "vinext start",
|
||||
"gen-doc-paths": "tsx ./scripts/gen-doc-paths.ts",
|
||||
"gen-icons": "node ./scripts/gen-icons.mjs && eslint --fix app/components/base/icons/src/",
|
||||
"i18n:check": "tsx ./scripts/check-i18n.js",
|
||||
"knip": "knip",
|
||||
"lint": "eslint --cache --concurrency=auto",
|
||||
"lint:ci": "eslint --cache --concurrency 2",
|
||||
"lint:fix": "pnpm lint --fix",
|
||||
"lint:quiet": "pnpm lint --quiet",
|
||||
"lint:complexity": "pnpm lint --rule 'complexity: [error, {max: 15}]' --quiet",
|
||||
"lint:report": "pnpm lint --output-file eslint_report.json --format json",
|
||||
"lint:tss": "tsslint --project tsconfig.json",
|
||||
"type-check": "tsc --noEmit",
|
||||
"type-check:tsgo": "tsgo --noEmit",
|
||||
"preinstall": "npx only-allow pnpm",
|
||||
"prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky ./web/.husky",
|
||||
"gen-icons": "node ./scripts/gen-icons.mjs && eslint --fix app/components/base/icons/src/",
|
||||
"gen-doc-paths": "tsx ./scripts/gen-doc-paths.ts",
|
||||
"uglify-embed": "node ./bin/uglify-embed",
|
||||
"i18n:check": "tsx ./scripts/check-i18n.js",
|
||||
"test": "vitest run",
|
||||
"test:coverage": "vitest run --coverage",
|
||||
"test:ci": "vitest run --coverage --silent=passed-only",
|
||||
"test:watch": "vitest --watch",
|
||||
"analyze-component": "node ./scripts/analyze-component.js",
|
||||
"refactor-component": "node ./scripts/refactor-component.js",
|
||||
"start": "node ./scripts/copy-and-start.mjs",
|
||||
"start:vinext": "vinext start",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"storybook:build": "storybook build",
|
||||
"preinstall": "npx only-allow pnpm",
|
||||
"analyze": "next experimental-analyze",
|
||||
"knip": "knip"
|
||||
"test": "vitest run",
|
||||
"test:ci": "vitest run --coverage --silent=passed-only",
|
||||
"test:coverage": "vitest run --coverage",
|
||||
"test:watch": "vitest --watch",
|
||||
"type-check": "tsc --noEmit",
|
||||
"type-check:tsgo": "tsgo --noEmit",
|
||||
"uglify-embed": "node ./bin/uglify-embed"
|
||||
},
|
||||
"dependencies": {
|
||||
"@amplitude/analytics-browser": "2.36.3",
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
# Production Build Optimization Scripts
|
||||
|
||||
## optimize-standalone.js
|
||||
|
||||
This script removes unnecessary development dependencies from the Next.js standalone build output to reduce the production Docker image size.
|
||||
|
||||
### What it does
|
||||
|
||||
The script specifically targets and removes `jest-worker` packages that are bundled with Next.js but not needed in production. These packages are included because:
|
||||
|
||||
1. Next.js includes jest-worker in its compiled dependencies
|
||||
1. terser-webpack-plugin (used by Next.js for minification) depends on jest-worker
|
||||
1. pnpm's dependency resolution creates symlinks to jest-worker in various locations
|
||||
|
||||
### Usage
|
||||
|
||||
The script is automatically run during Docker builds via the `build:docker` npm script:
|
||||
|
||||
```bash
|
||||
# Docker build (removes jest-worker after build)
|
||||
pnpm build:docker
|
||||
```
|
||||
|
||||
To run the optimization manually:
|
||||
|
||||
```bash
|
||||
node scripts/optimize-standalone.js
|
||||
```
|
||||
|
||||
### What gets removed
|
||||
|
||||
- `node_modules/.pnpm/next@*/node_modules/next/dist/compiled/jest-worker`
|
||||
- `node_modules/.pnpm/terser-webpack-plugin@*/node_modules/jest-worker` (symlinks)
|
||||
- `node_modules/.pnpm/jest-worker@*` (actual packages)
|
||||
|
||||
### Impact
|
||||
|
||||
Removing jest-worker saves approximately 36KB per instance from the production image. While this may seem small, it helps ensure production images only contain necessary runtime dependencies.
|
||||
@ -1,163 +0,0 @@
|
||||
/**
|
||||
* Script to optimize Next.js standalone output for production
|
||||
* Removes unnecessary files like jest-worker that are bundled with Next.js
|
||||
*/
|
||||
|
||||
import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
import { fileURLToPath } from 'node:url'
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url)
|
||||
const __dirname = path.dirname(__filename)
|
||||
|
||||
console.log('🔧 Optimizing standalone output...')
|
||||
|
||||
const standaloneDir = path.join(__dirname, '..', '.next', 'standalone')
|
||||
|
||||
// Check if standalone directory exists
|
||||
if (!fs.existsSync(standaloneDir)) {
|
||||
console.error('❌ Standalone directory not found. Please run "next build" first.')
|
||||
process.exit(1)
|
||||
}
|
||||
|
||||
// List of paths to remove (relative to standalone directory)
|
||||
const pathsToRemove = [
|
||||
// Remove jest-worker from Next.js compiled dependencies
|
||||
'node_modules/.pnpm/next@*/node_modules/next/dist/compiled/jest-worker',
|
||||
// Remove jest-worker symlinks from terser-webpack-plugin
|
||||
'node_modules/.pnpm/terser-webpack-plugin@*/node_modules/jest-worker',
|
||||
// Remove actual jest-worker packages (directories only, not symlinks)
|
||||
'node_modules/.pnpm/jest-worker@*',
|
||||
]
|
||||
|
||||
// Function to safely remove a path
|
||||
function removePath(basePath, relativePath) {
|
||||
const fullPath = path.join(basePath, relativePath)
|
||||
|
||||
// Handle wildcard patterns
|
||||
if (relativePath.includes('*')) {
|
||||
const parts = relativePath.split('/')
|
||||
let currentPath = basePath
|
||||
|
||||
for (let i = 0; i < parts.length; i++) {
|
||||
const part = parts[i]
|
||||
if (part.includes('*')) {
|
||||
// Find matching directories
|
||||
if (fs.existsSync(currentPath)) {
|
||||
const entries = fs.readdirSync(currentPath)
|
||||
|
||||
// replace '*' with '.*'
|
||||
const regexPattern = part.replace(/\*/g, '.*')
|
||||
|
||||
const regex = new RegExp(`^${regexPattern}$`)
|
||||
|
||||
for (const entry of entries) {
|
||||
if (regex.test(entry)) {
|
||||
const remainingPath = parts.slice(i + 1).join('/')
|
||||
const matchedPath = path.join(currentPath, entry, remainingPath)
|
||||
|
||||
try {
|
||||
// Use lstatSync to check if path exists (works for both files and symlinks)
|
||||
const stats = fs.lstatSync(matchedPath)
|
||||
|
||||
if (stats.isSymbolicLink()) {
|
||||
// Remove symlink
|
||||
fs.unlinkSync(matchedPath)
|
||||
console.log(`✅ Removed symlink: ${path.relative(basePath, matchedPath)}`)
|
||||
}
|
||||
else {
|
||||
// Remove directory/file
|
||||
fs.rmSync(matchedPath, { recursive: true, force: true })
|
||||
console.log(`✅ Removed: ${path.relative(basePath, matchedPath)}`)
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
// Silently ignore ENOENT (path not found) errors
|
||||
if (error.code !== 'ENOENT') {
|
||||
console.error(`❌ Failed to remove ${matchedPath}: ${error.message}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
else {
|
||||
currentPath = path.join(currentPath, part)
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Direct path removal
|
||||
if (fs.existsSync(fullPath)) {
|
||||
try {
|
||||
fs.rmSync(fullPath, { recursive: true, force: true })
|
||||
console.log(`✅ Removed: ${relativePath}`)
|
||||
}
|
||||
catch (error) {
|
||||
console.error(`❌ Failed to remove ${fullPath}: ${error.message}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove unnecessary paths
|
||||
console.log('🗑️ Removing unnecessary files...')
|
||||
for (const pathToRemove of pathsToRemove) {
|
||||
removePath(standaloneDir, pathToRemove)
|
||||
}
|
||||
|
||||
// Calculate size reduction
|
||||
console.log('\n📊 Optimization complete!')
|
||||
|
||||
// Optional: Display the size of remaining jest-related files (if any)
|
||||
const checkForJest = (dir) => {
|
||||
const jestFiles = []
|
||||
|
||||
function walk(currentPath) {
|
||||
if (!fs.existsSync(currentPath))
|
||||
return
|
||||
|
||||
try {
|
||||
const entries = fs.readdirSync(currentPath)
|
||||
for (const entry of entries) {
|
||||
const fullPath = path.join(currentPath, entry)
|
||||
|
||||
try {
|
||||
const stat = fs.lstatSync(fullPath) // Use lstatSync to handle symlinks
|
||||
|
||||
if (stat.isDirectory() && !stat.isSymbolicLink()) {
|
||||
// Skip node_modules subdirectories to avoid deep traversal
|
||||
if (entry === 'node_modules' && currentPath !== standaloneDir) {
|
||||
continue
|
||||
}
|
||||
walk(fullPath)
|
||||
}
|
||||
else if (stat.isFile() && entry.includes('jest')) {
|
||||
jestFiles.push(path.relative(standaloneDir, fullPath))
|
||||
}
|
||||
}
|
||||
catch (err) {
|
||||
// Skip files that can't be accessed
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (err) {
|
||||
// Skip directories that can't be read
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
walk(dir)
|
||||
return jestFiles
|
||||
}
|
||||
|
||||
const remainingJestFiles = checkForJest(standaloneDir)
|
||||
if (remainingJestFiles.length > 0) {
|
||||
console.log('\n⚠️ Warning: Some jest-related files still remain:')
|
||||
remainingJestFiles.forEach(file => console.log(` - ${file}`))
|
||||
}
|
||||
else {
|
||||
console.log('\n✨ No jest-related files found in standalone output!')
|
||||
}
|
||||
Reference in New Issue
Block a user