mirror of
https://github.com/langgenius/dify.git
synced 2026-07-01 11:26:49 +08:00
Compare commits
8 Commits
deploy/dev
...
fix/device
| Author | SHA1 | Date | |
|---|---|---|---|
| f903b980ea | |||
| e926ad213d | |||
| febace6bcf | |||
| cb35c6fa98 | |||
| 34f62e7df6 | |||
| 07b5dcbb19 | |||
| 23917c7b3e | |||
| 8a6ce28855 |
@ -13,8 +13,10 @@ handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from pydantic import ValidationError
|
||||
@ -74,6 +76,21 @@ STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
|
||||
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
_ALLOWED_SSO_ERRORS = {"sso_failed", "email_belongs_to_dify_account"}
|
||||
|
||||
# user_code only ever reaches the redirect as a urlencoded query value; the
|
||||
# charset bound additionally forbids the path/scheme separators a redirection
|
||||
# attack would need, so an untrusted value cannot escape the fixed /device path.
|
||||
_USER_CODE_RE = re.compile(r"\A[A-Z0-9-]{1,16}\Z")
|
||||
|
||||
|
||||
def _device_error_redirect(code: str, user_code: str | None = None):
|
||||
safe_code = code if code in _ALLOWED_SSO_ERRORS else "sso_failed"
|
||||
params: dict[str, str] = {"sso_error": safe_code}
|
||||
if user_code and _USER_CODE_RE.match(user_code):
|
||||
params["user_code"] = user_code
|
||||
return redirect(f"/device?{urlencode(params)}", code=302)
|
||||
|
||||
|
||||
def _trusted_origin() -> str:
|
||||
base = (dify_config.CONSOLE_API_URL or "").rstrip("/")
|
||||
@ -134,9 +151,21 @@ def sso_initiate():
|
||||
@bp.route("/oauth/device/sso-complete", methods=["GET"])
|
||||
@enterprise_only
|
||||
def sso_complete():
|
||||
try:
|
||||
return _sso_complete_impl()
|
||||
except Exception:
|
||||
logger.exception("sso-complete: unhandled")
|
||||
return _device_error_redirect("sso_failed")
|
||||
|
||||
|
||||
def _sso_complete_impl():
|
||||
inbound_error = request.args.get("sso_error")
|
||||
if inbound_error:
|
||||
return _device_error_redirect(inbound_error, request.args.get("user_code"))
|
||||
|
||||
blob = request.args.get("sso_assertion")
|
||||
if not blob:
|
||||
raise BadRequest("sso_assertion required")
|
||||
return _device_error_redirect("sso_failed")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
|
||||
@ -144,25 +173,26 @@ def sso_complete():
|
||||
raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
return _device_error_redirect("sso_failed")
|
||||
|
||||
try:
|
||||
claims = ExtSubjectAssertionClaims.model_validate(raw_claims)
|
||||
except ValidationError as e:
|
||||
logger.warning("sso-complete: claim shape invalid: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.nonce):
|
||||
raise BadRequest("invalid_sso_assertion")
|
||||
return _device_error_redirect("sso_failed")
|
||||
|
||||
user_code = claims.user_code.strip().upper()
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.nonce):
|
||||
return _device_error_redirect("sso_failed", user_code)
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise Conflict("user_code_not_pending")
|
||||
return _device_error_redirect("sso_failed", user_code)
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
return _device_error_redirect("sso_failed", user_code)
|
||||
|
||||
if AccountService.has_active_account_with_email(db.session, claims.email):
|
||||
_emit_external_rejection_audit(
|
||||
@ -170,7 +200,7 @@ def sso_complete():
|
||||
_RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer),
|
||||
reason="email_belongs_to_dify_account",
|
||||
)
|
||||
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
|
||||
return _device_error_redirect("email_belongs_to_dify_account", user_code)
|
||||
|
||||
iss = _trusted_origin()
|
||||
cookie_value, _ = mint_approval_grant(
|
||||
|
||||
@ -144,7 +144,7 @@ class AnalyticdbVectorBySql:
|
||||
f"id text PRIMARY KEY,"
|
||||
f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
|
||||
f"to_tsvector TSVECTOR"
|
||||
f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
|
||||
f") DISTRIBUTED BY (id);"
|
||||
)
|
||||
if embedding_dimension is not None:
|
||||
index_name = f"{self._collection_name}_embedding_idx"
|
||||
@ -153,7 +153,7 @@ class AnalyticdbVectorBySql:
|
||||
cur.execute(
|
||||
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
|
||||
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
|
||||
f"pq_enable=0, external_storage=0)"
|
||||
f"pq_enable=0)"
|
||||
)
|
||||
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
|
||||
except Exception as e:
|
||||
|
||||
@ -40,6 +40,7 @@ from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.workflow.entities import WebhookTriggerData
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
try:
|
||||
import magic
|
||||
@ -114,6 +115,7 @@ class WebhookService:
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == webhook_trigger.tenant_id,
|
||||
Workflow.app_id == webhook_trigger.app_id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
@ -125,6 +127,7 @@ class WebhookService:
|
||||
app_trigger = session.scalar(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == webhook_trigger.tenant_id,
|
||||
AppTrigger.app_id == webhook_trigger.app_id,
|
||||
AppTrigger.node_id == webhook_trigger.node_id,
|
||||
AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
|
||||
@ -145,16 +148,18 @@ class WebhookService:
|
||||
if app_trigger.status != AppTriggerStatus.ENABLED:
|
||||
raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")
|
||||
|
||||
# Get workflow
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
app = session.scalar(
|
||||
select(App)
|
||||
.where(
|
||||
Workflow.app_id == webhook_trigger.app_id,
|
||||
Workflow.version != Workflow.VERSION_DRAFT,
|
||||
App.tenant_id == webhook_trigger.tenant_id,
|
||||
App.id == webhook_trigger.app_id,
|
||||
)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
if not app:
|
||||
raise ValueError(f"App not found for webhook {webhook_id}")
|
||||
|
||||
workflow = WorkflowService().get_published_workflow(app, session=session)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")
|
||||
|
||||
|
||||
@ -333,7 +333,7 @@ class VectorService:
|
||||
|
||||
# Add documents to vector store if any
|
||||
if documents and dataset.is_multimodal:
|
||||
vector.add_texts(documents, duplicate_check=True)
|
||||
vector.create_multimodal(documents)
|
||||
|
||||
# Single commit for all operations
|
||||
db.session.commit()
|
||||
|
||||
@ -12,7 +12,7 @@ from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
@ -35,7 +35,7 @@ from models.enums import (
|
||||
WorkflowRunTriggeredFrom,
|
||||
WorkflowTriggerStatus,
|
||||
)
|
||||
from models.model import EndUser
|
||||
from models.model import App, EndUser
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerSubscription, WorkflowPluginTrigger, WorkflowTriggerLog
|
||||
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowRun
|
||||
@ -99,23 +99,25 @@ def dispatch_trigger_debug_event(
|
||||
return 0
|
||||
|
||||
|
||||
def _get_latest_workflows_by_app_ids(
|
||||
def _get_published_workflows_by_app_ids(
|
||||
session: Session, subscribers: Sequence[WorkflowPluginTrigger]
|
||||
) -> Mapping[str, Workflow]:
|
||||
"""Get the latest workflows by app_ids"""
|
||||
workflow_query = (
|
||||
select(Workflow.app_id, func.max(Workflow.created_at).label("max_created_at"))
|
||||
.where(
|
||||
Workflow.app_id.in_({t.app_id for t in subscribers}),
|
||||
Workflow.version != Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.group_by(Workflow.app_id)
|
||||
.subquery()
|
||||
)
|
||||
"""Get current published workflows through apps.workflow_id."""
|
||||
app_ids = {trigger.app_id for trigger in subscribers}
|
||||
tenant_ids = {trigger.tenant_id for trigger in subscribers}
|
||||
if not app_ids or not tenant_ids:
|
||||
return {}
|
||||
|
||||
workflows = session.scalars(
|
||||
select(Workflow).join(
|
||||
workflow_query,
|
||||
(Workflow.app_id == workflow_query.c.app_id) & (Workflow.created_at == workflow_query.c.max_created_at),
|
||||
select(Workflow)
|
||||
.join(App, App.workflow_id == Workflow.id)
|
||||
.where(
|
||||
App.id.in_(app_ids),
|
||||
App.tenant_id.in_(tenant_ids),
|
||||
App.workflow_id.isnot(None),
|
||||
Workflow.app_id == App.id,
|
||||
Workflow.tenant_id == App.tenant_id,
|
||||
Workflow.version != Workflow.VERSION_DRAFT,
|
||||
)
|
||||
).all()
|
||||
return {w.app_id: w for w in workflows}
|
||||
@ -262,7 +264,7 @@ def dispatch_triggered_workflow(
|
||||
|
||||
# Ensure expire_on_commit is set to False to remain workflows available
|
||||
with session_factory.create_session() as session:
|
||||
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
|
||||
workflows: Mapping[str, Workflow] = _get_published_workflows_by_app_ids(session, subscribers)
|
||||
|
||||
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
|
||||
type=EndUserType.TRIGGER,
|
||||
|
||||
@ -127,6 +127,9 @@ class TestWebhookService:
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
app.workflow_id = workflow.id
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create webhook trigger
|
||||
webhook_id = fake.uuid4()[:16]
|
||||
webhook_trigger = WorkflowWebhookTrigger(
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
@ -240,6 +241,40 @@ class TestWebhookServiceLookupWithContainers:
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_uses_app_workflow_id(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
current_workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
newer_workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-15.001"
|
||||
)
|
||||
current_workflow.created_at = datetime(2026, 4, 14)
|
||||
newer_workflow.created_at = datetime(2026, 4, 15)
|
||||
app.workflow_id = current_workflow.id
|
||||
db_session_with_containers.commit()
|
||||
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
factory.create_app_trigger(
|
||||
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED
|
||||
)
|
||||
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
webhook_trigger.webhook_id
|
||||
)
|
||||
|
||||
assert got_trigger.id == webhook_trigger.id
|
||||
assert got_workflow.id == current_workflow.id
|
||||
assert got_workflow.id != newer_workflow.id
|
||||
assert got_node_config["id"] == "node-1"
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""
|
||||
|
||||
import builtins
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -77,3 +78,97 @@ def test_sso_complete_idp_callback_url_uses_canonical_path():
|
||||
from controllers.openapi import oauth_device_sso
|
||||
|
||||
assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _device_error_redirect helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_device_error_redirect_builds_relative_location():
|
||||
from controllers.openapi import oauth_device_sso
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context():
|
||||
resp = oauth_device_sso._device_error_redirect("sso_failed", "ABCD-1234")
|
||||
assert resp.status_code == 302
|
||||
loc = resp.headers["Location"]
|
||||
assert loc.startswith("/device?")
|
||||
assert "sso_error=sso_failed" in loc
|
||||
assert "user_code=ABCD-1234" in loc
|
||||
|
||||
|
||||
def test_device_error_redirect_clamps_unknown_code():
|
||||
from controllers.openapi import oauth_device_sso
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context():
|
||||
resp = oauth_device_sso._device_error_redirect("totally-bogus")
|
||||
assert "sso_error=sso_failed" in resp.headers["Location"]
|
||||
|
||||
|
||||
def test_device_error_redirect_keeps_email_special_case():
|
||||
from controllers.openapi import oauth_device_sso
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context():
|
||||
resp = oauth_device_sso._device_error_redirect("email_belongs_to_dify_account", "ABCD-1234")
|
||||
assert "sso_error=email_belongs_to_dify_account" in resp.headers["Location"]
|
||||
|
||||
|
||||
def test_device_error_redirect_omits_empty_user_code():
|
||||
from controllers.openapi import oauth_device_sso
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context():
|
||||
resp = oauth_device_sso._device_error_redirect("sso_failed")
|
||||
assert "user_code=" not in resp.headers["Location"]
|
||||
|
||||
|
||||
def test_device_error_redirect_drops_malformed_user_code():
|
||||
from controllers.openapi import oauth_device_sso
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context():
|
||||
resp = oauth_device_sso._device_error_redirect("sso_failed", "https://evil.example/")
|
||||
loc = resp.headers["Location"]
|
||||
assert loc.startswith("/device?")
|
||||
assert "user_code=" not in loc
|
||||
assert "evil" not in loc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sso_complete redirect behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ee_features():
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
m = MagicMock()
|
||||
m.license.status = LicenseStatus.ACTIVE
|
||||
return m
|
||||
|
||||
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
def test_sso_complete_relays_inbound_sso_error(ee_feat, openapi_app):
|
||||
ee_feat.return_value = _ee_features()
|
||||
client = openapi_app.test_client()
|
||||
resp = client.get(
|
||||
"/openapi/v1/oauth/device/sso-complete?sso_error=sso_failed&user_code=ABCD-1234",
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert resp.status_code == 302
|
||||
loc = resp.headers["Location"]
|
||||
assert "/device?" in loc
|
||||
assert "sso_error=sso_failed" in loc
|
||||
assert "user_code=ABCD-1234" in loc
|
||||
|
||||
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
def test_sso_complete_missing_assertion_redirects_generic(ee_feat, openapi_app):
|
||||
ee_feat.return_value = _ee_features()
|
||||
client = openapi_app.test_client()
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-complete", follow_redirects=False)
|
||||
assert resp.status_code == 302
|
||||
assert "sso_error=sso_failed" in resp.headers["Location"]
|
||||
|
||||
@ -34,8 +34,9 @@ def test_sso_complete_rejects_assertion_missing_email(ee_feat, jws_mod, app: Fla
|
||||
jws_mod.VerifyError = Exception
|
||||
|
||||
client = app.test_client()
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
|
||||
assert resp.status_code == 400, resp.data
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob", follow_redirects=False)
|
||||
assert resp.status_code == 302, resp.data
|
||||
assert "sso_error=sso_failed" in resp.headers["Location"]
|
||||
|
||||
|
||||
@patch("controllers.openapi.oauth_device_sso.jws")
|
||||
@ -48,8 +49,9 @@ def test_sso_complete_rejects_assertion_empty_issuer(ee_feat, jws_mod, app: Flas
|
||||
jws_mod.VerifyError = Exception
|
||||
|
||||
client = app.test_client()
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
|
||||
assert resp.status_code == 400
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob", follow_redirects=False)
|
||||
assert resp.status_code == 302
|
||||
assert "sso_error=sso_failed" in resp.headers["Location"]
|
||||
|
||||
|
||||
def test_verify_approval_grant_raises_on_missing_field():
|
||||
|
||||
@ -639,8 +639,8 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up
|
||||
assert len(bindings) == 1
|
||||
assert bindings[0]["attachment_id"] == "file-1"
|
||||
|
||||
vector_instance.add_texts.assert_called_once()
|
||||
documents = vector_instance.add_texts.call_args.args[0]
|
||||
vector_instance.create_multimodal.assert_called_once()
|
||||
documents = vector_instance.create_multimodal.call_args.args[0]
|
||||
assert len(documents) == 1
|
||||
assert documents[0].page_content == "img.png"
|
||||
assert documents[0].metadata["doc_id"] == "file-1"
|
||||
|
||||
@ -98,7 +98,7 @@ class TestDispatchTriggeredWorkflow:
|
||||
),
|
||||
patch.object(
|
||||
trigger_processing_tasks_module,
|
||||
"_get_latest_workflows_by_app_ids",
|
||||
"_get_published_workflows_by_app_ids",
|
||||
) as get_workflows,
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.EndUserService,
|
||||
|
||||
@ -232,7 +232,20 @@ describe('Billing Page + Plan Integration', () => {
|
||||
|
||||
// Verify billing URL button visibility and behavior
|
||||
describe('Billing URL button', () => {
|
||||
it('should show billing button when subscription management permission is granted', () => {
|
||||
it('should show billing button when manager has subscription management permission', () => {
|
||||
setupProviderContext({ type: Plan.sandbox })
|
||||
setupAppContext({
|
||||
isCurrentWorkspaceManager: true,
|
||||
workspacePermissionKeys: ['billing.subscription.manage'],
|
||||
})
|
||||
|
||||
render(<Billing />)
|
||||
|
||||
expect(screen.getByText(/viewBillingTitle/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/viewBillingAction/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide billing button when subscription management permission is granted without manager role', () => {
|
||||
setupProviderContext({ type: Plan.sandbox })
|
||||
setupAppContext({
|
||||
isCurrentWorkspaceManager: false,
|
||||
@ -241,8 +254,7 @@ describe('Billing Page + Plan Integration', () => {
|
||||
|
||||
render(<Billing />)
|
||||
|
||||
expect(screen.getByText(/viewBillingTitle/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/viewBillingAction/i)).toBeInTheDocument()
|
||||
expect(screen.queryByText(/viewBillingTitle/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide billing button when subscription management permission is missing', () => {
|
||||
|
||||
@ -21,6 +21,7 @@ let mockChatConversationDetail: Record<string, unknown> | undefined
|
||||
let mockCompletionConversationDetail: Record<string, unknown> | undefined
|
||||
let mockShowMessageLogModal = false
|
||||
let mockShowPromptLogModal = false
|
||||
let mockShowAgentLogModal = false
|
||||
let mockCurrentLogItem: Record<string, unknown> | undefined
|
||||
let mockCurrentLogModalActiveTab = 'messages'
|
||||
|
||||
@ -81,6 +82,7 @@ vi.mock('@/app/components/app/store', () => ({
|
||||
setShowAgentLogModal: mockSetShowAgentLogModal,
|
||||
setShowMessageLogModal: mockSetShowMessageLogModal,
|
||||
showPromptLogModal: mockShowPromptLogModal,
|
||||
showAgentLogModal: mockShowAgentLogModal,
|
||||
currentLogModalActiveTab: mockCurrentLogModalActiveTab,
|
||||
}),
|
||||
}))
|
||||
@ -126,6 +128,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
onAnnotationEdited,
|
||||
onAnnotationRemoved,
|
||||
switchSibling,
|
||||
hideLogModal,
|
||||
}: {
|
||||
chatList: Array<{ id: string }>
|
||||
onFeedback: (mid: string, value: { rating: string, content?: string }) => Promise<boolean>
|
||||
@ -133,8 +136,9 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
onAnnotationEdited: (query: string, answer: string, index: number) => void
|
||||
onAnnotationRemoved: (index: number) => Promise<boolean>
|
||||
switchSibling: (siblingMessageId: string) => void
|
||||
hideLogModal?: boolean
|
||||
}) => (
|
||||
<div data-testid="chat-panel">
|
||||
<div data-testid="chat-panel" data-hide-log-modal={String(hideLogModal)}>
|
||||
<div>{chatList.length}</div>
|
||||
<button onClick={() => void onFeedback('message-1', { rating: 'like', content: 'nice' })}>chat-feedback</button>
|
||||
<button onClick={() => onAnnotationAdded('annotation-2', 'Admin', 'Edited question', 'Edited answer', 1)}>chat-add-annotation</button>
|
||||
@ -145,6 +149,14 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/agent-log-modal', () => ({
|
||||
default: ({ floating, onCancel }: { floating?: boolean, onCancel: () => void }) => (
|
||||
<div data-testid="agent-log-modal" data-floating={String(floating)}>
|
||||
<button onClick={onCancel}>close-agent-log-modal</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/message-log-modal', () => ({
|
||||
default: ({ onCancel }: { onCancel: () => void }) => (
|
||||
<div data-testid="message-log-modal">
|
||||
@ -255,6 +267,7 @@ describe('ConversationList', () => {
|
||||
mockCompletionConversationDetail = undefined
|
||||
mockShowMessageLogModal = false
|
||||
mockShowPromptLogModal = false
|
||||
mockShowAgentLogModal = false
|
||||
mockCurrentLogItem = undefined
|
||||
mockCurrentLogModalActiveTab = 'messages'
|
||||
mockDelAnnotation.mockResolvedValue(undefined)
|
||||
@ -383,6 +396,7 @@ describe('ConversationList', () => {
|
||||
|
||||
expect(screen.getByTestId('var-panel')).toHaveTextContent('query:Latest question')
|
||||
expect(screen.getByTestId('model-info')).toHaveTextContent('gpt-4o')
|
||||
expect(screen.getByTestId('chat-panel')).toHaveAttribute('data-hide-log-modal', 'true')
|
||||
expect(screen.getByTestId('message-log-modal')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('chat-feedback'))
|
||||
@ -399,6 +413,61 @@ describe('ConversationList', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should mount agent log modals from the detail panel instead of the nested chat layout', async () => {
|
||||
mockChatConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
model_config: {
|
||||
model: 'gpt-4o',
|
||||
configs: {
|
||||
introduction: 'Hello there',
|
||||
},
|
||||
user_input_form: [],
|
||||
},
|
||||
message: {
|
||||
inputs: {},
|
||||
},
|
||||
}
|
||||
mockShowAgentLogModal = true
|
||||
mockCurrentLogItem = {
|
||||
id: 'message-1',
|
||||
conversationId: 'conversation-1',
|
||||
}
|
||||
mockFetchChatMessages.mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
id: 'message-1',
|
||||
answer: 'Assistant reply',
|
||||
query: 'Latest question',
|
||||
created_at: 1710000000,
|
||||
inputs: {},
|
||||
feedbacks: [],
|
||||
message: [],
|
||||
message_files: [],
|
||||
agent_thoughts: [{ id: 'thought-1' }],
|
||||
},
|
||||
],
|
||||
has_more: false,
|
||||
})
|
||||
|
||||
renderConversationList({
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('chat-panel')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('chat-panel')).toHaveAttribute('data-hide-log-modal', 'true')
|
||||
expect(screen.getByTestId('agent-log-modal')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('agent-log-modal')).toHaveAttribute('data-floating', 'true')
|
||||
|
||||
fireEvent.click(screen.getByText('close-agent-log-modal'))
|
||||
|
||||
expect(mockSetCurrentLogItem).toHaveBeenCalled()
|
||||
expect(mockSetShowAgentLogModal).toHaveBeenCalledWith(false)
|
||||
})
|
||||
|
||||
it('should render completion details and refetch after feedback updates', async () => {
|
||||
mockCompletionConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
@ -424,7 +493,7 @@ describe('ConversationList', () => {
|
||||
},
|
||||
}
|
||||
mockShowPromptLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-2' }
|
||||
mockCurrentLogItem = { id: 'log-2', log: [{ role: 'user', text: 'Prompt body' }] }
|
||||
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
|
||||
@ -626,7 +695,7 @@ describe('ConversationList', () => {
|
||||
},
|
||||
}
|
||||
mockShowPromptLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-2' }
|
||||
mockCurrentLogItem = { id: 'log-2', log: [{ role: 'user', text: 'Prompt body' }] }
|
||||
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
|
||||
|
||||
@ -36,6 +36,7 @@ import ModelInfo from '@/app/components/app/log/model-info'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import TextGeneration from '@/app/components/app/text-generate/item'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import AgentLogModal from '@/app/components/base/agent-log-modal'
|
||||
import Chat from '@/app/components/base/chat/chat'
|
||||
import CopyIcon from '@/app/components/base/copy-icon'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
@ -165,13 +166,25 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
})
|
||||
const { formatTime } = useTimestamp()
|
||||
const { onClose, appDetail } = useContext(DrawerContext)
|
||||
const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow((state: AppStoreState) => ({
|
||||
const {
|
||||
currentLogItem,
|
||||
setCurrentLogItem,
|
||||
showMessageLogModal,
|
||||
setShowMessageLogModal,
|
||||
showPromptLogModal,
|
||||
setShowPromptLogModal,
|
||||
showAgentLogModal,
|
||||
setShowAgentLogModal,
|
||||
currentLogModalActiveTab,
|
||||
} = useAppStore(useShallow((state: AppStoreState) => ({
|
||||
currentLogItem: state.currentLogItem,
|
||||
setCurrentLogItem: state.setCurrentLogItem,
|
||||
showMessageLogModal: state.showMessageLogModal,
|
||||
setShowMessageLogModal: state.setShowMessageLogModal,
|
||||
showPromptLogModal: state.showPromptLogModal,
|
||||
setShowPromptLogModal: state.setShowPromptLogModal,
|
||||
showAgentLogModal: state.showAgentLogModal,
|
||||
setShowAgentLogModal: state.setShowAgentLogModal,
|
||||
currentLogModalActiveTab: state.currentLogModalActiveTab,
|
||||
})))
|
||||
const { t } = useTranslation()
|
||||
@ -395,6 +408,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
|
||||
const isChatMode = appDetail?.mode !== AppModeEnum.COMPLETION
|
||||
const isAdvanced = appDetail?.mode === AppModeEnum.ADVANCED_CHAT
|
||||
const shouldShowPromptLogModal = showPromptLogModal && !!currentLogItem?.log
|
||||
|
||||
const varList = getDetailVarList(detail, varValues)
|
||||
const message_files = getCompletionMessageFiles(detail, isChatMode)
|
||||
@ -507,6 +521,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
noChatInput
|
||||
showPromptLog
|
||||
hideProcessDetail
|
||||
hideLogModal
|
||||
chatContainerInnerClassName="px-3"
|
||||
switchSibling={switchSibling}
|
||||
/>
|
||||
@ -546,6 +561,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
noChatInput
|
||||
showPromptLog
|
||||
hideProcessDetail
|
||||
hideLogModal
|
||||
chatContainerInnerClassName="px-3"
|
||||
switchSibling={switchSibling}
|
||||
/>
|
||||
@ -574,7 +590,18 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
/>
|
||||
</WorkflowContextProvider>
|
||||
)}
|
||||
{!isChatMode && showPromptLogModal && (
|
||||
{showAgentLogModal && (
|
||||
<AgentLogModal
|
||||
floating
|
||||
width={width}
|
||||
currentLogItem={currentLogItem}
|
||||
onCancel={() => {
|
||||
setCurrentLogItem()
|
||||
setShowAgentLogModal(false)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{shouldShowPromptLogModal && (
|
||||
<PromptLogModal
|
||||
width={width}
|
||||
currentLogItem={currentLogItem}
|
||||
|
||||
@ -119,6 +119,17 @@ describe('AgentLogModal', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should render the floating modal through a dialog portal', () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
|
||||
const { container } = render(<AgentLogModal {...mockProps} floating />)
|
||||
|
||||
const modal = screen.getByRole('dialog')
|
||||
expect(container).not.toContainElement(modal)
|
||||
expect(document.body).toContainElement(modal)
|
||||
expect(modal).toHaveClass('fixed', 'z-50', 'w-[480px]!', 'left-[max(8px,calc(100vw-1136px))]!')
|
||||
})
|
||||
|
||||
it('should call onCancel when close button is clicked', () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
|
||||
@ -158,4 +169,18 @@ describe('AgentLogModal', () => {
|
||||
|
||||
expect(mockProps.onCancel).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not use click-away to close the floating dialog', () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
|
||||
let clickAwayHandler!: (event: Event) => void
|
||||
vi.mocked(useClickAway).mockImplementation((callback) => {
|
||||
clickAwayHandler = callback
|
||||
})
|
||||
|
||||
render(<AgentLogModal {...mockProps} floating />)
|
||||
clickAwayHandler(new Event('click'))
|
||||
|
||||
expect(mockProps.onCancel).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { FC } from 'react'
|
||||
import type { IChatItem } from '@/app/components/base/chat/chat/type'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import { Dialog, DialogContent, DialogTitle } from '@langgenius/dify-ui/dialog'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import { useClickAway } from 'ahooks'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
@ -10,11 +11,13 @@ import AgentLogDetail from './detail'
|
||||
type AgentLogModalProps = Readonly<{
|
||||
currentLogItem?: IChatItem
|
||||
width: number
|
||||
floating?: boolean
|
||||
onCancel: () => void
|
||||
}>
|
||||
const AgentLogModal: FC<AgentLogModalProps> = ({
|
||||
currentLogItem,
|
||||
width,
|
||||
floating,
|
||||
onCancel,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
@ -22,7 +25,7 @@ const AgentLogModal: FC<AgentLogModalProps> = ({
|
||||
const [mounted, setMounted] = useState(false)
|
||||
|
||||
useClickAway(() => {
|
||||
if (mounted)
|
||||
if (mounted && !floating)
|
||||
onCancel()
|
||||
}, ref)
|
||||
|
||||
@ -33,6 +36,44 @@ const AgentLogModal: FC<AgentLogModalProps> = ({
|
||||
if (!currentLogItem || !currentLogItem.conversationId)
|
||||
return null
|
||||
|
||||
const detailContent = (
|
||||
<>
|
||||
<AgentLogDetail
|
||||
conversationID={currentLogItem.conversationId}
|
||||
messageID={currentLogItem.id}
|
||||
log={currentLogItem}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
|
||||
if (floating) {
|
||||
return (
|
||||
<Dialog
|
||||
open
|
||||
onOpenChange={(open) => {
|
||||
if (!open)
|
||||
onCancel()
|
||||
}}
|
||||
>
|
||||
<DialogContent
|
||||
backdropClassName="bg-transparent!"
|
||||
className="top-16! bottom-4! left-[max(8px,calc(100vw-1136px))]! flex max-h-none! w-[480px]! max-w-[calc(100vw-16px)]! translate-x-0! translate-y-0! flex-col overflow-hidden! rounded-xl! border-[0.5px]! border-components-panel-border! bg-components-panel-bg! p-0! pt-3! pb-3! shadow-xl!"
|
||||
>
|
||||
<DialogTitle className="text-md shrink-0 px-4 py-1 font-semibold text-text-primary">{t('runDetail.workflowTitle', { ns: 'appLog' })}</DialogTitle>
|
||||
<button
|
||||
type="button"
|
||||
aria-label={t('operation.close', { ns: 'common' })}
|
||||
className="absolute top-4 right-3 z-20 cursor-pointer border-none bg-transparent p-1 focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:outline-hidden"
|
||||
onClick={onCancel}
|
||||
>
|
||||
<RiCloseLine className="size-4 text-text-tertiary" aria-hidden="true" />
|
||||
</button>
|
||||
{detailContent}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn('relative z-10 flex flex-col rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg py-3 shadow-xl')}
|
||||
@ -54,11 +95,7 @@ const AgentLogModal: FC<AgentLogModalProps> = ({
|
||||
>
|
||||
<RiCloseLine className="size-4 text-text-tertiary" aria-hidden="true" />
|
||||
</button>
|
||||
<AgentLogDetail
|
||||
conversationID={currentLogItem.conversationId}
|
||||
messageID={currentLogItem.id}
|
||||
log={currentLogItem}
|
||||
/>
|
||||
{detailContent}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ let fetching = false
|
||||
let isManager = true
|
||||
let enableBilling = true
|
||||
let workspacePermissionKeys: string[] = ['billing.subscription.manage']
|
||||
let billingUrlEnabled = false
|
||||
|
||||
const refetchMock = vi.fn()
|
||||
const openAsyncWindowMock = vi.fn()
|
||||
@ -19,11 +20,14 @@ type BillingWindowOptions = {
|
||||
type OpenAsyncWindowCall = [BillingUrlCallback, BillingWindowOptions]
|
||||
|
||||
vi.mock('@/service/use-billing', () => ({
|
||||
useBillingUrl: () => ({
|
||||
data: currentBillingUrl,
|
||||
isFetching: fetching,
|
||||
refetch: refetchMock,
|
||||
}),
|
||||
useBillingUrl: (enabled: boolean) => {
|
||||
billingUrlEnabled = enabled
|
||||
return {
|
||||
data: currentBillingUrl,
|
||||
isFetching: fetching,
|
||||
refetch: refetchMock,
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
@ -54,28 +58,32 @@ describe('Billing', () => {
|
||||
fetching = false
|
||||
isManager = true
|
||||
enableBilling = true
|
||||
billingUrlEnabled = false
|
||||
workspacePermissionKeys = ['billing.subscription.manage']
|
||||
refetchMock.mockResolvedValue({ data: 'https://billing' })
|
||||
})
|
||||
|
||||
it('shows the billing action when subscription management permission is granted without manager role', () => {
|
||||
it('hides the billing action when subscription management permission is granted without manager role', () => {
|
||||
isManager = false
|
||||
|
||||
render(<Billing />)
|
||||
|
||||
expect(screen.getByRole('button', { name: /billing\.viewBillingTitle/ })).toBeInTheDocument()
|
||||
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
|
||||
expect(billingUrlEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('hides the billing action when subscription management permission is missing or billing is disabled', () => {
|
||||
workspacePermissionKeys = []
|
||||
render(<Billing />)
|
||||
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
|
||||
expect(billingUrlEnabled).toBe(false)
|
||||
|
||||
vi.clearAllMocks()
|
||||
workspacePermissionKeys = ['billing.subscription.manage']
|
||||
enableBilling = false
|
||||
render(<Billing />)
|
||||
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
|
||||
expect(billingUrlEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('opens the billing window with the immediate url when the button is clicked', async () => {
|
||||
|
||||
@ -11,9 +11,9 @@ import PlanComp from '../plan'
|
||||
|
||||
const Billing: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const { workspacePermissionKeys } = useAppContext()
|
||||
const { isCurrentWorkspaceManager, workspacePermissionKeys } = useAppContext()
|
||||
const { enableBilling } = useProviderContext()
|
||||
const canManageBillingSubscription = hasPermission(workspacePermissionKeys, BillingPermission.SubscriptionManage)
|
||||
const canManageBillingSubscription = isCurrentWorkspaceManager && hasPermission(workspacePermissionKeys, BillingPermission.SubscriptionManage)
|
||||
const { data: billingUrl, isFetching, refetch } = useBillingUrl(enableBilling && canManageBillingSubscription)
|
||||
const openAsyncWindow = useAsyncWindowOpen()
|
||||
|
||||
|
||||
@ -121,40 +121,50 @@ describe('error_lookup_failed terminal state', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('sso_error inline banner on the code-entry page', () => {
|
||||
const SSO_BANNER_COPY = /identity is linked to a Dify account/i
|
||||
describe('error_sso dedicated view', () => {
|
||||
const GENERIC = /single sign-on could not be completed/i
|
||||
|
||||
it('shows the error banner with friendly copy when sso_error is present', async () => {
|
||||
mockSearchParams = { sso_error: 'email_belongs_to_dify_account' }
|
||||
it('renders the dedicated SSO error screen (not the code-entry page)', async () => {
|
||||
mockSearchParams = { sso_error: 'sso_failed', user_code: 'ABCD-3456' }
|
||||
render(<DevicePage />)
|
||||
expect(await screen.findByText(SSO_BANNER_COPY)).toBeInTheDocument()
|
||||
expect(await screen.findByText('Sign-in could not be completed')).toBeInTheDocument()
|
||||
expect(await screen.findByText(GENERIC)).toBeInTheDocument()
|
||||
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('keeps the code-entry screen visible (error on main page, not a separate view)', async () => {
|
||||
mockSearchParams = { sso_error: 'email_belongs_to_dify_account' }
|
||||
it('shows the email special-case copy', async () => {
|
||||
mockSearchParams = { sso_error: 'email_belongs_to_dify_account', user_code: 'ABCD-3456' }
|
||||
render(<DevicePage />)
|
||||
await screen.findByText(SSO_BANNER_COPY)
|
||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /Continue/i })).toBeInTheDocument()
|
||||
expect(await screen.findByText(/identity is linked to a Dify account/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('does not surface the raw backend error code', async () => {
|
||||
mockSearchParams = { sso_error: 'email_belongs_to_dify_account' }
|
||||
it('never surfaces the raw backend code', async () => {
|
||||
mockSearchParams = { sso_error: 'email_belongs_to_dify_account', user_code: 'ABCD-3456' }
|
||||
render(<DevicePage />)
|
||||
await screen.findByText(SSO_BANNER_COPY)
|
||||
await screen.findByText(/identity is linked to a Dify account/i)
|
||||
expect(screen.queryByText('email_belongs_to_dify_account')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('does not scrub the param on mount (regression: error was wiped by router.replace)', async () => {
|
||||
mockSearchParams = { sso_error: 'email_belongs_to_dify_account' }
|
||||
it('scrubs sso_error + user_code from the URL on mount', async () => {
|
||||
mockSearchParams = { sso_error: 'sso_failed', user_code: 'ABCD-3456' }
|
||||
render(<DevicePage />)
|
||||
await screen.findByText(SSO_BANNER_COPY)
|
||||
expect(mockReplace).not.toHaveBeenCalled()
|
||||
await screen.findByText('Sign-in could not be completed')
|
||||
expect(mockReplace).toHaveBeenCalledWith('/device')
|
||||
})
|
||||
|
||||
it('shows no banner when sso_error is absent', () => {
|
||||
it('"Back to login options" re-checks the code and advances to the chooser', async () => {
|
||||
mockSearchParams = { sso_error: 'sso_failed', user_code: 'ABCD-3456' }
|
||||
mockDeviceLookup.mockResolvedValue({ valid: true })
|
||||
render(<DevicePage />)
|
||||
await screen.findByText('Sign-in could not be completed')
|
||||
fireEvent.click(screen.getByRole('button', { name: /Back to login options/i }))
|
||||
await screen.findByText(/is valid/i)
|
||||
expect(mockDeviceLookup).toHaveBeenCalledWith('ABCD-3456')
|
||||
})
|
||||
|
||||
it('shows no SSO error screen when sso_error is absent', () => {
|
||||
render(<DevicePage />)
|
||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
||||
expect(screen.queryByText(SSO_BANNER_COPY)).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('Sign-in could not be completed')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -25,6 +25,7 @@ type View
|
||||
| { kind: 'error_expired' }
|
||||
| { kind: 'error_rate_limited' }
|
||||
| { kind: 'error_lookup_failed' }
|
||||
| { kind: 'error_sso', code: string, userCode: string }
|
||||
|
||||
export default function DevicePage() {
|
||||
const searchParams = useSearchParams()
|
||||
@ -72,6 +73,11 @@ export default function DevicePage() {
|
||||
useEffect(() => {
|
||||
if (view.kind !== 'code_entry' && view.kind !== 'chooser')
|
||||
return
|
||||
if (ssoError) {
|
||||
setView({ kind: 'error_sso', code: ssoError, userCode: urlUserCode }) // eslint-disable-line react/set-state-in-effect
|
||||
router.replace(pathname)
|
||||
return
|
||||
}
|
||||
// Post-login bounce: chooser holds the typed code, account just loaded.
|
||||
// The URL was already scrubbed on the first effect run, so urlUserCode
|
||||
// is empty here — advance using the userCode stashed in view state.
|
||||
@ -93,13 +99,11 @@ export default function DevicePage() {
|
||||
}
|
||||
if (consumed && (urlUserCode || ssoVerified))
|
||||
router.replace(pathname)
|
||||
}, [urlUserCode, ssoVerified, account, view, router, pathname])
|
||||
}, [urlUserCode, ssoVerified, ssoError, account, view, router, pathname])
|
||||
|
||||
const onContinue = async () => {
|
||||
if (!isValidUserCode(typed))
|
||||
return
|
||||
const advanceFromCode = async (code: string) => {
|
||||
try {
|
||||
const reply = await deviceLookup(typed)
|
||||
const reply = await deviceLookup(code)
|
||||
if (!reply.valid) {
|
||||
setView({ kind: 'error_expired' })
|
||||
return
|
||||
@ -116,20 +120,20 @@ export default function DevicePage() {
|
||||
return
|
||||
}
|
||||
if (account)
|
||||
setView({ kind: 'authorize_account', userCode: typed })
|
||||
else setView({ kind: 'chooser', userCode: typed })
|
||||
setView({ kind: 'authorize_account', userCode: code })
|
||||
else setView({ kind: 'chooser', userCode: code })
|
||||
}
|
||||
|
||||
const onContinue = async () => {
|
||||
if (!isValidUserCode(typed))
|
||||
return
|
||||
await advanceFromCode(typed)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{view.kind === 'code_entry' && (
|
||||
<div className="flex flex-col gap-5">
|
||||
{ssoError && (
|
||||
<div className="flex items-start gap-2 rounded-lg bg-state-destructive-hover p-3">
|
||||
<span className="mt-0.5 i-ri-close-circle-line h-4 w-4 shrink-0 text-util-colors-red-red-600" />
|
||||
<p className="text-sm text-text-destructive">{ssoErrorCopy(ssoError)}</p>
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
@ -270,6 +274,31 @@ export default function DevicePage() {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{view.kind === 'error_sso' && (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="mb-2.5 flex h-[38px] w-[38px] items-center justify-center rounded-full bg-state-warning-hover">
|
||||
<span aria-hidden="true" className="i-ri-error-warning-line h-[18px] w-[18px] text-util-colors-yellow-yellow-600" />
|
||||
</div>
|
||||
<h1 className="text-xl font-semibold text-text-primary">Sign-in could not be completed</h1>
|
||||
<p className="text-sm text-text-secondary">{ssoErrorCopy(view.code)}</p>
|
||||
<Divider className="my-3" />
|
||||
<Button
|
||||
variant="primary"
|
||||
size="large"
|
||||
className="w-full"
|
||||
onClick={() => {
|
||||
setErrMsg(null)
|
||||
if (view.userCode)
|
||||
advanceFromCode(view.userCode)
|
||||
else
|
||||
setView({ kind: 'code_entry' })
|
||||
}}
|
||||
>
|
||||
← Back to login options
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{errMsg && (
|
||||
<p className="mt-4 text-sm text-text-destructive">{errMsg}</p>
|
||||
)}
|
||||
|
||||
Reference in New Issue
Block a user