diff --git a/.github/workflows/deploy-agent-dev.yml b/.github/workflows/deploy-saas.yml similarity index 69% rename from .github/workflows/deploy-agent-dev.yml rename to .github/workflows/deploy-saas.yml index 9b9b77e0a2..b00883c8c7 100644 --- a/.github/workflows/deploy-agent-dev.yml +++ b/.github/workflows/deploy-saas.yml @@ -1,4 +1,4 @@ -name: Deploy Agent Dev +name: Deploy SaaS permissions: contents: read @@ -7,7 +7,7 @@ on: workflow_run: workflows: ["Build and Push API & Web"] branches: - - "deploy/agent-dev" + - "deploy/saas" types: - completed @@ -16,13 +16,13 @@ jobs: runs-on: depot-ubuntu-24.04 if: | github.event.workflow_run.conclusion == 'success' && - github.event.workflow_run.head_branch == 'deploy/agent-dev' + github.event.workflow_run.head_branch == 'deploy/saas' steps: - name: Deploy to server uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5 with: - host: ${{ secrets.AGENT_DEV_SSH_HOST }} + host: ${{ secrets.SAAS_DEV_SSH_HOST }} username: ${{ secrets.SSH_USER }} key: ${{ secrets.SSH_PRIVATE_KEY }} script: | - ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }} + ${{ vars.SSH_SCRIPT_SAAS_DEV || secrets.SSH_SCRIPT_SAAS_DEV }} diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c40ca5c1ea..e18607a60e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -95,6 +95,51 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: ./.github/actions/setup-web + - name: Web tsslint + if: steps.changed-files.outputs.any_changed == 'true' + env: + NODE_OPTIONS: --max-old-space-size=4096 + run: vp run lint:tss + + - name: Web dead code check + if: steps.changed-files.outputs.any_changed == 'true' + run: vp run knip + + ts-common-style: + name: TS Common + runs-on: depot-ubuntu-24.04 + permissions: + checks: write + pull-requests: read + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Check changed files + id: changed-files + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 + with: + files: | + web/** + cli/** + e2e/** + sdks/nodejs-client/** + packages/** + package.json + pnpm-lock.yaml + pnpm-workspace.yaml + .nvmrc + eslint.config.mjs + .github/workflows/style.yml + .github/actions/setup-web/** + + - name: Setup web environment + if: steps.changed-files.outputs.any_changed == 'true' + uses: ./.github/actions/setup-web + - name: Restore ESLint cache if: steps.changed-files.outputs.any_changed == 'true' id: eslint-cache-restore @@ -105,28 +150,14 @@ jobs: restore-keys: | ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- - - name: Web style check + - name: Style check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: . run: vp run lint:ci - - name: Web tsslint + - name: Type check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - env: - NODE_OPTIONS: --max-old-space-size=4096 - run: vp run lint:tss - - - name: Web type check - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: . run: vp run type-check - - name: Web dead code check - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: vp run knip - - name: Save ESLint cache if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 12ab692adb..8906d544e8 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -6,7 +6,7 @@ from flask_restx import Resource from flask_restx.api import HTTPStatus from pydantic import BaseModel, Field, TypeAdapter -from controllers.common.schema import register_schema_models +from controllers.common.schema import query_params_from_model, register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token @@ -32,8 +32,19 @@ class AnnotationReplyActionPayload(BaseModel): embedding_model_name: str = Field(description="Embedding model name") +class AnnotationListQuery(BaseModel): + page: int = Field(default=1, ge=1, description="Page number") + limit: int = Field(default=20, ge=1, description="Number of annotations per page") + keyword: str = Field(default="", description="Keyword to search annotations") + + register_schema_models( - service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList + service_api_ns, + AnnotationCreatePayload, + AnnotationReplyActionPayload, + AnnotationListQuery, + Annotation, + AnnotationList, ) @@ -100,6 +111,7 @@ class AnnotationReplyActionStatusApi(Resource): class AnnotationListApi(Resource): @service_api_ns.doc("list_annotations") @service_api_ns.doc(description="List annotations for the application") + @service_api_ns.doc(params=query_params_from_model(AnnotationListQuery)) @service_api_ns.doc( responses={ 200: "Annotations retrieved successfully", @@ -114,18 +126,18 @@ class AnnotationListApi(Resource): @validate_app_token def get(self, app_model: App): """List annotations for the application.""" - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - keyword = request.args.get("keyword", default="", type=str) + query = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) - annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app_model.id, query.page, query.limit, query.keyword + ) annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) response = AnnotationList( data=annotation_models, - has_more=len(annotation_list) == limit, - limit=limit, + has_more=len(annotation_list) == query.limit, + limit=query.limit, total=total, - page=page, + page=query.page, ) return response.model_dump(mode="json") diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 297279954e..96fc55526c 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -562,15 +562,16 @@ class WorkflowResponseConverter: outputs, outputs_truncated = self._truncate_mapping(encoded_outputs) metadata = self._merge_metadata(event.execution_metadata, snapshot) - if isinstance(event, QueueNodeSucceededEvent): - status = WorkflowNodeExecutionStatus.SUCCEEDED - error_message = event.error - elif isinstance(event, QueueNodeFailedEvent): - status = WorkflowNodeExecutionStatus.FAILED - error_message = event.error - else: - status = WorkflowNodeExecutionStatus.EXCEPTION - error_message = event.error + match event: + case QueueNodeSucceededEvent(): + status = WorkflowNodeExecutionStatus.SUCCEEDED + error_message = event.error + case QueueNodeFailedEvent(): + status = WorkflowNodeExecutionStatus.FAILED + error_message = event.error + case _: + status = WorkflowNodeExecutionStatus.EXCEPTION + error_message = event.error return NodeFinishStreamResponse( task_id=task_id, diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 7002b1a470..b00d2b5613 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -91,26 +91,28 @@ class AppGeneratorTTSPublisher: ) future_queue.put(futures_result) break - elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): - message_content = message.event.chunk.delta.message.content - if not message_content: - continue - match message_content: - case str(): - self.msg_text += message_content - case list(): - for content in message_content: - if not isinstance(content, TextPromptMessageContent): - continue - self.msg_text += content.data - elif isinstance(message.event, QueueTextChunkEvent): - self.msg_text += message.event.text - elif isinstance(message.event, QueueNodeSucceededEvent): - if message.event.outputs is None: - continue - output = message.event.outputs.get("output", "") - if isinstance(output, str): - self.msg_text += output + else: + match message.event: + case QueueAgentMessageEvent() | QueueLLMChunkEvent(): + message_content = message.event.chunk.delta.message.content + if not message_content: + continue + match message_content: + case str(): + self.msg_text += message_content + case list(): + for content in message_content: + if not isinstance(content, TextPromptMessageContent): + continue + self.msg_text += content.data + case QueueTextChunkEvent(): + self.msg_text += message.event.text + case QueueNodeSucceededEvent(): + if message.event.outputs is None: + continue + output = message.event.outputs.get("output", "") + if isinstance(output, str): + self.msg_text += output self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) if len(sentence_arr) >= min(self.max_sentence, 7): diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index b2e6d782d8..4537c1b537 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -54,36 +54,39 @@ class Blob(BaseModel): def as_string(self) -> str: """Read data as a string.""" - if self.data is None and self.path: - return Path(str(self.path)).read_text(encoding=self.encoding) - elif isinstance(self.data, bytes): - return self.data.decode(self.encoding) - elif isinstance(self.data, str): - return self.data - else: - raise ValueError(f"Unable to get string for blob {self}") + match self.data: + case None if self.path: + return Path(str(self.path)).read_text(encoding=self.encoding) + case bytes(): + return self.data.decode(self.encoding) + case str(): + return self.data + case _: + raise ValueError(f"Unable to get string for blob {self}") def as_bytes(self) -> bytes: """Read data as bytes.""" - if isinstance(self.data, bytes): - return self.data - elif isinstance(self.data, str): - return self.data.encode(self.encoding) - elif self.data is None and self.path: - return Path(str(self.path)).read_bytes() - else: - raise ValueError(f"Unable to get bytes for blob {self}") + match self.data: + case bytes(): + return self.data + case str(): + return self.data.encode(self.encoding) + case None if self.path: + return Path(str(self.path)).read_bytes() + case _: + raise ValueError(f"Unable to get bytes for blob {self}") @contextlib.contextmanager def as_bytes_io(self) -> Generator[BytesIO | BufferedReader, None, None]: """Read data as a byte stream.""" - if isinstance(self.data, bytes): - yield BytesIO(self.data) - elif self.data is None and self.path: - with open(str(self.path), "rb") as f: - yield f - else: - raise NotImplementedError(f"Unable to convert blob {self}") + match self.data: + case bytes(): + yield BytesIO(self.data) + case None if self.path: + with open(str(self.path), "rb") as f: + yield f + case _: + raise NotImplementedError(f"Unable to convert blob {self}") @classmethod def from_path( diff --git a/api/openapi/markdown/service-swagger.md b/api/openapi/markdown/service-swagger.md index 071b1b526c..11881c5d38 100644 --- a/api/openapi/markdown/service-swagger.md +++ b/api/openapi/markdown/service-swagger.md @@ -112,6 +112,14 @@ List annotations for the application List annotations for the application +##### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| keyword | query | Keyword to search annotations | No | string | +| limit | query | Number of annotations per page | No | integer | +| page | query | Page number | No | integer | + ##### Responses | Code | Description | Schema | @@ -2169,6 +2177,14 @@ Returns a list of available models for the specified model type. | page | integer | | Yes | | total | integer | | Yes | +#### AnnotationListQuery + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| keyword | string | Keyword to search annotations | No | +| limit | integer | Number of annotations per page | No | +| page | integer | Page number | No | + #### AnnotationReplyActionPayload | Name | Type | Description | Required | diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 2e593ea71f..c1fe769997 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2329,15 +2329,15 @@ class DocumentService: # if knowledge_config.data_source: # if knowledge_config.data_source.info_list.data_source_type == "upload_file": # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids - # # type: ignore + # # count = len(upload_file_list) # elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # notion_info_list = knowledge_config.data_source.info_list.notion_info_list - # for notion_info in notion_info_list: # type: ignore + # for notion_info in notion_info_list: # count = count + len(notion_info.pages) # elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # website_info = knowledge_config.data_source.info_list.website_info_list - # count = len(website_info.urls) # type: ignore + # count = len(website_info.urls) # batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) # if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1: @@ -2349,7 +2349,7 @@ class DocumentService: # # if dataset is empty, update dataset data_source_type # if not dataset.data_source_type: - # dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + # dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # if not dataset.indexing_technique: # if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: @@ -2386,7 +2386,7 @@ class DocumentService: # knowledge_config.retrieval_model.model_dump() # if knowledge_config.retrieval_model # else default_retrieval_model - # ) # type: ignore + # ) # documents = [] # if knowledge_config.original_document_id: @@ -2425,8 +2425,8 @@ class DocumentService: # position = DocumentService.get_documents_position(dataset.id) # document_ids = [] # duplicate_document_ids = [] - # if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore - # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + # if knowledge_config.data_source.info_list.data_source_type == "upload_file": + # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # for file_id in upload_file_list: # file = ( # db.session.query(UploadFile) @@ -2452,7 +2452,7 @@ class DocumentService: # name=file_name, # ).first() # if document: - # document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + # document.dataset_process_rule_id = dataset_process_rule.id # document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # document.created_from = created_from # document.doc_form = knowledge_config.doc_form @@ -2466,8 +2466,8 @@ class DocumentService: # continue # document = DocumentService.build_document( # dataset, - # dataset_process_rule.id, # type: ignore - # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # dataset_process_rule.id, + # knowledge_config.data_source.info_list.data_source_type, # knowledge_config.doc_form, # knowledge_config.doc_language, # data_source_info, @@ -2482,8 +2482,8 @@ class DocumentService: # document_ids.append(document.id) # documents.append(document) # position += 1 - # elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore - # notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + # elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + # notion_info_list = knowledge_config.data_source.info_list.notion_info_list # if not notion_info_list: # raise ValueError("No notion info list found.") # exist_page_ids = [] @@ -2523,8 +2523,8 @@ class DocumentService: # truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" # document = DocumentService.build_document( # dataset, - # dataset_process_rule.id, # type: ignore - # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # dataset_process_rule.id, + # knowledge_config.data_source.info_list.data_source_type, # knowledge_config.doc_form, # knowledge_config.doc_language, # data_source_info, @@ -2544,8 +2544,8 @@ class DocumentService: # # delete not selected documents # if len(exist_document) > 0: # clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - # elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - # website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + # elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + # website_info = knowledge_config.data_source.info_list.website_info_list # if not website_info: # raise ValueError("No website info list found.") # urls = website_info.urls @@ -2563,8 +2563,8 @@ class DocumentService: # document_name = url # document = DocumentService.build_document( # dataset, - # dataset_process_rule.id, # type: ignore - # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # dataset_process_rule.id, + # knowledge_config.data_source.info_list.data_source_type, # knowledge_config.doc_form, # knowledge_config.doc_language, # data_source_info, diff --git a/api/tests/unit_tests/controllers/service_api/app/test_annotation.py b/api/tests/unit_tests/controllers/service_api/app/test_annotation.py index 2ab5547cd4..6d586d31a9 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_annotation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_annotation.py @@ -19,10 +19,12 @@ from unittest.mock import Mock import pytest from flask import Flask from flask_restx.api import HTTPStatus +from pydantic import ValidationError from controllers.service_api.app.annotation import ( AnnotationCreatePayload, AnnotationListApi, + AnnotationListQuery, AnnotationReplyActionApi, AnnotationReplyActionPayload, AnnotationReplyActionStatusApi, @@ -106,6 +108,28 @@ class TestAnnotationReplyActionPayload: assert payload.score_threshold == 0.0 +class TestAnnotationListQuery: + def test_defaults(self) -> None: + query = AnnotationListQuery.model_validate({}) + + assert query.page == 1 + assert query.limit == 20 + assert query.keyword == "" + + def test_valid_numeric_strings(self) -> None: + query = AnnotationListQuery.model_validate({"page": "2", "limit": "5", "keyword": "refund"}) + + assert query.page == 2 + assert query.limit == 5 + assert query.keyword == "refund" + + @pytest.mark.parametrize("field", ["page", "limit"]) + @pytest.mark.parametrize("value", ["abc", "1.5", "1e2", "", "0", "-1"]) + def test_invalid_explicit_pagination_value(self, field: str, value: str) -> None: + with pytest.raises(ValidationError): + AnnotationListQuery.model_validate({field: value}) + + # --------------------------------------------------------------------------- # Model and Error Pattern Tests # --------------------------------------------------------------------------- @@ -232,22 +256,55 @@ class TestAnnotationReplyActionStatusApi: class TestAnnotationListApi: - def test_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + def test_get_uses_defaults(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0) - monkeypatch.setattr( - AppAnnotationService, - "get_annotation_list_by_app_id", - lambda *_args, **_kwargs: ([annotation], 1), - ) + get_mock = Mock(return_value=([annotation], 1)) + monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock) api = AnnotationListApi() handler = _unwrap(api.get) app_model = SimpleNamespace(id="app") - with app.test_request_context("/apps/annotations?page=1&limit=1", method="GET"): + with app.test_request_context("/apps/annotations", method="GET"): + response = handler(api, app_model=app_model) + + assert response["page"] == 1 + assert response["limit"] == 20 + get_mock.assert_called_once_with("app", 1, 20, "") + + def test_get_accepts_valid_numeric_strings(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0) + get_mock = Mock(return_value=([annotation], 1)) + monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock) + + api = AnnotationListApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app") + + with app.test_request_context("/apps/annotations?page=2&limit=5&keyword=refund", method="GET"): response = handler(api, app_model=app_model) assert response["total"] == 1 + assert response["page"] == 2 + assert response["limit"] == 5 + get_mock.assert_called_once_with("app", 2, 5, "refund") + + @pytest.mark.parametrize("query_string", ["page=abc&limit=5", "page=1&limit=abc", "page=&limit=5", "limit=0"]) + def test_get_rejects_invalid_explicit_pagination_value( + self, app: Flask, monkeypatch: pytest.MonkeyPatch, query_string: str + ) -> None: + get_mock = Mock(return_value=([], 0)) + monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock) + + api = AnnotationListApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app") + + with app.test_request_context(f"/apps/annotations?{query_string}", method="GET"): + with pytest.raises(ValidationError): + handler(api, app_model=app_model) + + get_mock.assert_not_called() def test_create(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0) diff --git a/cli/src/api/oauth-device.ts b/cli/src/api/oauth-device.ts index 1368a7617b..582b5f06b2 100644 --- a/cli/src/api/oauth-device.ts +++ b/cli/src/api/oauth-device.ts @@ -40,7 +40,7 @@ export type PollSuccess = { subject_type?: string subject_email?: string subject_issuer?: string - account?: PollAccount + account?: PollAccount | null workspaces?: readonly PollWorkspace[] default_workspace_id?: string token_id?: string diff --git a/cli/src/commands/auth/login/login.test.ts b/cli/src/commands/auth/login/login.test.ts index c93a2a824f..8436473634 100644 --- a/cli/src/commands/auth/login/login.test.ts +++ b/cli/src/commands/auth/login/login.test.ts @@ -106,6 +106,8 @@ describe('runLogin', () => { expect(bundle.account).toBeUndefined() expect(bundle.external_subject?.email).toBe('sso@dify.ai') expect(bundle.external_subject?.issuer).toBe('https://issuer.example') + const stored = await store.get(bundle.current_host, 'sso@dify.ai') + expect(stored).toBe('dfoe_test') expect(io.outBuf()).toContain('external SSO') expect(io.outBuf()).toContain('sso@dify.ai') }) diff --git a/cli/src/commands/auth/login/login.ts b/cli/src/commands/auth/login/login.ts index b06dca24f3..d30ebee26a 100644 --- a/cli/src/commands/auth/login/login.ts +++ b/cli/src/commands/auth/login/login.ts @@ -99,7 +99,7 @@ function renderCodePrompt(w: NodeJS.WritableStream, cs: ReturnType, host: string, s: PollSuccess): void { const display = bareHost(host) - if (s.account !== undefined && s.account.email !== '') { + if (s.account && s.account.email !== '') { w.write(`${cs.successIcon()} Logged in to ${display} as ${cs.bold(s.account.email)} (${s.account.name})\n`) const ws = findDefaultWorkspace(s) if (ws !== undefined) @@ -139,11 +139,11 @@ function bundleFromSuccess(host: string, s: PollSuccess, mode: StorageMode): Hos token_id: s.token_id, tokens: { bearer: s.token }, } - if (s.account !== undefined) { + if (s.account) { bundle.account = { id: s.account.id, email: s.account.email, name: s.account.name } } if (s.subject_email !== undefined && s.subject_email !== '' - && (s.account === undefined || s.account.id === '')) { + && (!s.account || s.account.id === '')) { bundle.external_subject = { email: s.subject_email, issuer: s.subject_issuer ?? '', diff --git a/cli/src/commands/create/member/index.ts b/cli/src/commands/create/member/index.ts index fe5b712769..7a0ee2a935 100644 --- a/cli/src/commands/create/member/index.ts +++ b/cli/src/commands/create/member/index.ts @@ -1,5 +1,5 @@ import { Flags } from '../../../framework/flags.js' -import { formatted } from '../../../framework/output.js' +import { formatted, OutputFormat } from '../../../framework/output.js' import { DifyCommand } from '../../_shared/dify-command.js' import { httpRetryFlag } from '../../_shared/global-flags.js' import { runCreateMember } from './run.js' @@ -24,7 +24,7 @@ export default class CreateMember extends DifyCommand { description: 'workspace id (overrides DIFY_WORKSPACE_ID and stored default)', }), 'http-retry': httpRetryFlag, - 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|text)', default: '' }), + 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.NAME, OutputFormat.TEXT], default: '' }), } async run(argv: string[]) { diff --git a/cli/src/commands/delete/member/index.ts b/cli/src/commands/delete/member/index.ts index f455de9fbb..1ec7956502 100644 --- a/cli/src/commands/delete/member/index.ts +++ b/cli/src/commands/delete/member/index.ts @@ -1,5 +1,5 @@ import { Args, Flags } from '../../../framework/flags.js' -import { formatted } from '../../../framework/output.js' +import { formatted, OutputFormat } from '../../../framework/output.js' import { DifyCommand } from '../../_shared/dify-command.js' import { httpRetryFlag } from '../../_shared/global-flags.js' import { runDeleteMember } from './run.js' @@ -23,7 +23,7 @@ export default class DeleteMember extends DifyCommand { description: 'workspace id (overrides DIFY_WORKSPACE_ID and stored default)', }), 'http-retry': httpRetryFlag, - 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|text)', default: '' }), + 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.NAME, OutputFormat.TEXT], default: '' }), 'yes': Flags.boolean({ char: 'y', description: 'skip confirmation prompt', default: false }), } diff --git a/cli/src/commands/describe/app/index.ts b/cli/src/commands/describe/app/index.ts index 514201bdf0..61a87005d3 100644 --- a/cli/src/commands/describe/app/index.ts +++ b/cli/src/commands/describe/app/index.ts @@ -1,5 +1,5 @@ import { Args, Flags } from '../../../framework/flags.js' -import { formatted } from '../../../framework/output.js' +import { formatted, OutputFormat } from '../../../framework/output.js' import { DifyCommand } from '../../_shared/dify-command.js' import { httpRetryFlag } from '../../_shared/global-flags.js' import { runDescribeApp } from './run.js' @@ -20,7 +20,7 @@ export default class DescribeApp extends DifyCommand { static override flags = { 'workspace': Flags.string({ description: 'workspace id (overrides DIFY_WORKSPACE_ID and stored default)' }), 'http-retry': httpRetryFlag, - 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|text)', default: '' }), + 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.TEXT], default: '' }), 'refresh': Flags.boolean({ description: 'bypass app-info cache and fetch fresh', default: false }), } diff --git a/cli/src/commands/get/app/index.ts b/cli/src/commands/get/app/index.ts index c5ce8516aa..c4ec2bd06c 100644 --- a/cli/src/commands/get/app/index.ts +++ b/cli/src/commands/get/app/index.ts @@ -1,6 +1,6 @@ import type { AppMode } from '@dify/contracts/api/openapi/types.gen' import { Args, Flags } from '../../../framework/flags.js' -import { table } from '../../../framework/output.js' +import { OutputFormat, table } from '../../../framework/output.js' import { DifyCommand } from '../../_shared/dify-command.js' import { httpRetryFlag } from '../../_shared/global-flags.js' import { runGetApp } from './run.js' @@ -42,7 +42,7 @@ export default class GetApp extends DifyCommand { 'name': Flags.string({ description: 'filter by app name (server-side substring)' }), 'tag': Flags.string({ description: 'filter by tag name (server-side exact match)' }), 'http-retry': httpRetryFlag, - 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|wide)', default: '' }), + 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.NAME, OutputFormat.WIDE], default: '' }), } async run(argv: string[]) { diff --git a/cli/src/commands/get/member/index.ts b/cli/src/commands/get/member/index.ts index 44a3dd241a..b1b4d82033 100644 --- a/cli/src/commands/get/member/index.ts +++ b/cli/src/commands/get/member/index.ts @@ -1,5 +1,5 @@ import { Flags } from '../../../framework/flags.js' -import { table } from '../../../framework/output.js' +import { OutputFormat, table } from '../../../framework/output.js' import { DifyCommand } from '../../_shared/dify-command.js' import { httpRetryFlag } from '../../_shared/global-flags.js' import { runGetMember } from './run.js' @@ -23,7 +23,7 @@ export default class GetMember extends DifyCommand { 'page': Flags.integer({ description: 'page number', default: 1 }), 'limit': Flags.string({ description: 'page size [1..200]' }), 'http-retry': httpRetryFlag, - 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|wide)', default: '' }), + 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.NAME, OutputFormat.WIDE], default: '' }), } async run(argv: string[]) { diff --git a/cli/src/commands/get/workspace/index.ts b/cli/src/commands/get/workspace/index.ts index f1edd17b03..3364f50761 100644 --- a/cli/src/commands/get/workspace/index.ts +++ b/cli/src/commands/get/workspace/index.ts @@ -1,5 +1,5 @@ import { Flags } from '../../../framework/flags.js' -import { raw, table } from '../../../framework/output.js' +import { OutputFormat, raw, table } from '../../../framework/output.js' import { DifyCommand } from '../../_shared/dify-command.js' import { httpRetryFlag } from '../../_shared/global-flags.js' import { runGetWorkspace } from './run.js' @@ -15,7 +15,7 @@ export default class GetWorkspace extends DifyCommand { static override flags = { 'http-retry': httpRetryFlag, - 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|wide)', default: '' }), + 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.NAME, OutputFormat.WIDE], default: '' }), } async run(argv: string[]) { diff --git a/cli/src/commands/resume/app/index.ts b/cli/src/commands/resume/app/index.ts index 6498549493..99ef349d7b 100644 --- a/cli/src/commands/resume/app/index.ts +++ b/cli/src/commands/resume/app/index.ts @@ -1,4 +1,5 @@ import { Args, Flags } from '../../../framework/flags.js' +import { OutputFormat } from '../../../framework/output.js' import { DifyCommand } from '../../_shared/dify-command.js' import { httpRetryFlag } from '../../_shared/global-flags.js' import { resumeApp } from './run.js' @@ -25,7 +26,7 @@ export default class ResumeApp extends DifyCommand { 'with-history': Flags.boolean({ description: 'Replay executed-node history before attaching to live stream.', default: false }), 'stream': Flags.boolean({ description: 'Print output live as tokens/events arrive. Default: collect and print at end.', default: false }), 'think': Flags.boolean({ description: 'Show model thinking/reasoning when available. Strips ... blocks silently by default; with --think, thinking is printed to stderr.', default: false }), - 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|text)', default: '' }), + 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.TEXT], default: '' }), 'http-retry': httpRetryFlag, } diff --git a/cli/src/commands/run/app/index.ts b/cli/src/commands/run/app/index.ts index 4799303249..77a55cf5b0 100644 --- a/cli/src/commands/run/app/index.ts +++ b/cli/src/commands/run/app/index.ts @@ -1,4 +1,5 @@ import { Args, Flags } from '../../../framework/flags.js' +import { OutputFormat } from '../../../framework/output.js' import { DifyCommand } from '../../_shared/dify-command.js' import { httpRetryFlag } from '../../_shared/global-flags.js' import { agentGuide } from './guide.js' @@ -32,7 +33,7 @@ export default class RunApp extends DifyCommand { 'stream': Flags.boolean({ description: 'Print output live as tokens/events arrive (default: collect and print at end)', default: false }), 'think': Flags.boolean({ description: 'Show model thinking/reasoning when available. Strips ... blocks silently by default; with --think, thinking is printed to stderr.', default: false }), 'http-retry': httpRetryFlag, - 'output': Flags.string({ char: 'o', description: 'Output format (json|yaml|text)', default: '' }), + 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.TEXT], default: '' }), } async run(argv: string[]): Promise { diff --git a/cli/src/commands/set/member/index.ts b/cli/src/commands/set/member/index.ts index 3cbf3bf106..f9ffdb286c 100644 --- a/cli/src/commands/set/member/index.ts +++ b/cli/src/commands/set/member/index.ts @@ -1,5 +1,5 @@ import { Args, Flags } from '../../../framework/flags.js' -import { formatted } from '../../../framework/output.js' +import { formatted, OutputFormat } from '../../../framework/output.js' import { DifyCommand } from '../../_shared/dify-command.js' import { httpRetryFlag } from '../../_shared/global-flags.js' import { runSetMember } from './run.js' @@ -27,7 +27,7 @@ export default class SetMember extends DifyCommand { description: 'workspace id (overrides DIFY_WORKSPACE_ID and stored default)', }), 'http-retry': httpRetryFlag, - 'output': Flags.string({ char: 'o', description: 'output format (json|yaml|name|text)', default: '' }), + 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.NAME, OutputFormat.TEXT], default: '' }), } async run(argv: string[]) { diff --git a/cli/src/commands/version/index.ts b/cli/src/commands/version/index.ts index 01d398a001..884824d232 100644 --- a/cli/src/commands/version/index.ts +++ b/cli/src/commands/version/index.ts @@ -1,5 +1,5 @@ import { Flags } from '../../framework/flags.js' -import { formatted, raw, stringifyOutput } from '../../framework/output.js' +import { formatted, OutputFormat, raw, stringifyOutput } from '../../framework/output.js' import { colorEnabled } from '../../sys/io/color.js' import { realStreams } from '../../sys/io/streams' import { versionInfo } from '../../version/info.js' @@ -20,11 +20,7 @@ export default class Version extends DifyCommand { ] static override flags = { - 'output': Flags.string({ - char: 'o', - description: 'output format (text|json|yaml)', - default: '', - }), + 'output': Flags.outputFormat({ options: [OutputFormat.TEXT, OutputFormat.JSON, OutputFormat.YAML], default: '' }), 'client': Flags.boolean({ description: 'skip server probe' }), 'short': Flags.boolean({ description: 'print only the client semver' }), 'check-compat': Flags.boolean({ diff --git a/cli/src/errors/codes.test.ts b/cli/src/errors/codes.test.ts index 101eb2eead..fcaf55e00a 100644 --- a/cli/src/errors/codes.test.ts +++ b/cli/src/errors/codes.test.ts @@ -8,8 +8,8 @@ import { } from './codes.js' describe('error codes', () => { - it('has 18 codes (parity with internal/api/errors)', () => { - expect(ALL_ERROR_CODES).toHaveLength(18) + it('has correct number codes (parity with internal/api/errors)', () => { + expect(ALL_ERROR_CODES).toHaveLength(Object.keys(CODE_TO_EXIT_MAP).length) }) it('has the expected ExitCode buckets', () => { diff --git a/cli/src/errors/codes.ts b/cli/src/errors/codes.ts index f435476812..0d157f90cb 100644 --- a/cli/src/errors/codes.ts +++ b/cli/src/errors/codes.ts @@ -17,6 +17,7 @@ export const ErrorCode = { Server4xxOther: 'server_4xx_other', ClientError: 'client_error', Unknown: 'unknown', + IllegalArgumentError: 'illegal_argument', } as const export type ErrorCodeValue = (typeof ErrorCode)[keyof typeof ErrorCode] @@ -50,6 +51,7 @@ const CODE_TO_EXIT: Readonly> = { server_4xx_other: ExitCode.Generic, client_error: ExitCode.Generic, unknown: ExitCode.Generic, + illegal_argument: ExitCode.Usage, } export function exitFor(code: string): ExitCodeValue { diff --git a/cli/src/framework/errors.test.ts b/cli/src/framework/errors.test.ts new file mode 100644 index 0000000000..90b9061648 --- /dev/null +++ b/cli/src/framework/errors.test.ts @@ -0,0 +1,34 @@ +import type { FlagDefinition } from './types.js' +import { describe, expect, it } from 'vitest' +import { OutputFormatNotSupportedError, UnsupportedArgValueError } from './errors.js' + +describe('OutputFormatNotSupportedError', () => { + it('states the offending format in the message', () => { + const err = new OutputFormatNotSupportedError('csv') + expect(err.message).toBe('format csv is not supported by this command') + }) +}) + +describe('UnsupportedArgValueError', () => { + it('includes both long and short option labels when a char exists', () => { + const def: FlagDefinition = { type: 'string', description: 'output', char: 'o', options: ['json', 'yaml'] } + const err = new UnsupportedArgValueError('output', def, 'csv') + expect(err.message).toBe('illegal value csv for flag --output / -o') + }) + + it('omits the short option label when the flag has no char', () => { + const def: FlagDefinition = { type: 'string', description: 'app mode', options: ['chat', 'workflow'] } + const err = new UnsupportedArgValueError('mode', def, 'chatbot') + expect(err.message).toBe('illegal value chatbot for flag --mode') + }) + + it('lists supported values in the hint', () => { + const def: FlagDefinition = { type: 'string', description: 'app mode', options: ['chat', 'workflow'] } + expect(new UnsupportedArgValueError('mode', def, 'chatbot').hint).toBe('supported value: chat, workflow') + }) + + it('leaves the hint empty when the flag declares no options', () => { + const def: FlagDefinition = { type: 'string', description: 'app mode' } + expect(new UnsupportedArgValueError('mode', def, 'chatbot').hint).toBe('') + }) +}) diff --git a/cli/src/framework/errors.ts b/cli/src/framework/errors.ts new file mode 100644 index 0000000000..b1cb4e1c29 --- /dev/null +++ b/cli/src/framework/errors.ts @@ -0,0 +1,23 @@ +import type { FlagDefinition } from './types' +import { BaseError } from '../errors/base' +import { ErrorCode } from '../errors/codes' + +export class OutputFormatNotSupportedError extends BaseError { + constructor(format: string) { + super({ + code: ErrorCode.IllegalArgumentError, + message: `format ${format} is not supported by this command`, + }) + } +} + +export class UnsupportedArgValueError extends BaseError { + constructor(flagName: string, flagDef: FlagDefinition, givenValue: string) { + const flagLabel = flagDef.char ? `--${flagName} / -${flagDef.char}` : `--${flagName}` + super({ + code: ErrorCode.IllegalArgumentError, + message: `illegal value ${givenValue} for flag ${flagLabel}`, + hint: flagDef.options ? `supported value: ${flagDef.options.join(', ')}` : '', + }) + } +} diff --git a/cli/src/framework/flags.test.ts b/cli/src/framework/flags.test.ts index efab143a06..c68d7733e4 100644 --- a/cli/src/framework/flags.test.ts +++ b/cli/src/framework/flags.test.ts @@ -1,4 +1,5 @@ import { describe, expect, it } from 'vitest' +import { UnsupportedArgValueError } from './errors.js' import { Args, Flags, parseArgv } from './flags.js' const meta = { @@ -190,13 +191,13 @@ describe('parseArgv', () => { it('rejects an invalid option value (space form)', () => { expect(() => parseArgv(['--mode', 'chatbot'], metaWithOptions)).toThrow( - '--mode must be one of: chat, workflow, completion', + UnsupportedArgValueError, ) }) it('rejects an invalid option value (= form)', () => { expect(() => parseArgv(['--mode=chatbot'], metaWithOptions)).toThrow( - '--mode must be one of: chat, workflow, completion', + UnsupportedArgValueError, ) }) }) diff --git a/cli/src/framework/flags.ts b/cli/src/framework/flags.ts index c5a3baa583..b730e0de9b 100644 --- a/cli/src/framework/flags.ts +++ b/cli/src/framework/flags.ts @@ -1,6 +1,12 @@ import type { ArgDefinition, CommandMeta, FlagDefinition, ParsedArgs, ParsedFlags } from './types.js' +import { UnsupportedArgValueError } from './errors.js' -function stringFlag( +function stringFlag( opts: Opts, ): FlagDefinition { return { @@ -10,7 +16,19 @@ function stringFlag( +function outputFormatFlag( + opts: Opts, +): FlagDefinition { + return { + type: 'string', + description: `output format (${opts.options.join('|')})`, + char: 'o', + multiple: false, + ...opts, + } +} + +function stringRepeatedFlag( opts: Opts, ): FlagDefinition { return { @@ -20,11 +38,11 @@ function stringRepeatedFlag { +function booleanFlag(opts: { description: string, char?: string, default?: boolean }): FlagDefinition { return { type: 'boolean', ...opts } } -function integerFlag( +function integerFlag( opts: Opts, ): FlagDefinition { return { type: 'integer', ...opts } as FlagDefinition @@ -35,6 +53,7 @@ export const Flags = { stringArray: stringRepeatedFlag, boolean: booleanFlag, integer: integerFlag, + outputFormat: outputFormatFlag, } function stringArg( @@ -91,7 +110,32 @@ function resolveByChar(char: string, meta: CommandMeta): [name: string, def: Fla function validateFlagOptions(name: string, raw: string, def: FlagDefinition): void { if (def.options !== undefined && !def.options.includes(raw)) - throw new Error(`--${name} must be one of: ${def.options.join(', ')}`) + throw new UnsupportedArgValueError(name, def, raw) +} + +type ResolvedFlag = { name: string, def: FlagDefinition, label: string, inlineRaw: string | undefined } + +function resolveToken(token: string, meta: CommandMeta): ResolvedFlag | null { + if (token.startsWith('--')) { + const eqIdx = token.indexOf('=') + const name = eqIdx !== -1 ? token.slice(2, eqIdx) : token.slice(2) + const inlineRaw = eqIdx !== -1 ? token.slice(eqIdx + 1) : undefined + const def = meta.flags[name] + if (!def) + throw new Error(`unknown flag: --${name}`) + return { name, def, label: `--${name}`, inlineRaw } + } + + if (token.length === 2 && token[1] !== undefined) { + const char = token[1] + const resolved = resolveByChar(char, meta) + if (!resolved) + throw new Error(`unknown flag: -${char}`) + const [name, def] = resolved + return { name, def, label: `-${char}`, inlineRaw: undefined } + } + + return null } export function parseArgv(argv: readonly string[], meta: CommandMeta): { args: ParsedArgs, flags: ParsedFlags } { @@ -110,63 +154,38 @@ export function parseArgv(argv: readonly string[], meta: CommandMeta): { args: P continue } - if (!pastDoubleDash && token.startsWith('--')) { - const eqIdx = token.indexOf('=') - let name: string - let rawValue: string | undefined - - if (eqIdx !== -1) { - name = token.slice(2, eqIdx) - rawValue = token.slice(eqIdx + 1) - } - else { - name = token.slice(2) - rawValue = undefined - } - - const def = meta.flags[name] - if (!def) - throw new Error(`unknown flag: --${name}`) - - if (def.type === 'boolean') { - flags[name] = rawValue === undefined ? true : coerceFlagValue(rawValue, def) - } - else if (rawValue !== undefined) { - validateFlagOptions(name, rawValue, def) - accumulateFlagValue(flags, name, coerceFlagValue(rawValue, def), def) - } - else { - i++ - const next = i < argv.length ? argv[i] : undefined - if (next === undefined || next.startsWith('-')) - throw new Error(`flag --${name} expects a value`) - - validateFlagOptions(name, next, def) - accumulateFlagValue(flags, name, coerceFlagValue(next, def), def) - } + if (pastDoubleDash || !token.startsWith('-')) { + positional.push(token) + continue } - else if (!pastDoubleDash && token.startsWith('-') && token.length === 2 && token[1] !== undefined) { - const char = token[1] - const resolved = resolveByChar(char, meta) - if (!resolved) - throw new Error(`unknown flag: -${char}`) - const [flagName, def] = resolved - if (def.type === 'boolean') { - flags[flagName] = true - } - else { - i++ - const next = i < argv.length ? argv[i] : undefined - if (next === undefined || next.startsWith('-')) - throw new Error(`flag -${char} expects a value`) + const resolved = resolveToken(token, meta) + if (!resolved) { + positional.push(token) + continue + } - accumulateFlagValue(flags, flagName, coerceFlagValue(next, def), def) - } + const { name, def, label, inlineRaw } = resolved + + if (def.type === 'boolean') { + flags[name] = inlineRaw === undefined ? true : coerceFlagValue(inlineRaw, def) + continue + } + + let raw: string + if (inlineRaw !== undefined) { + raw = inlineRaw } else { - positional.push(token) + i++ + const next = i < argv.length ? argv[i] : undefined + if (next === undefined || next.startsWith('-')) + throw new Error(`flag ${label} expects a value`) + raw = next } + + validateFlagOptions(name, raw, def) + accumulateFlagValue(flags, name, coerceFlagValue(raw, def), def) } const args: ParsedArgs = {} diff --git a/cli/src/framework/output.test.ts b/cli/src/framework/output.test.ts index accd8328b5..df5ba2ff26 100644 --- a/cli/src/framework/output.test.ts +++ b/cli/src/framework/output.test.ts @@ -1,5 +1,6 @@ import type { FormattedPrintable, NamePrintable, TablePrintable } from './output.js' import { describe, expect, it } from 'vitest' +import { OutputFormatNotSupportedError } from './errors.js' import { formatted, raw, @@ -99,13 +100,12 @@ describe('stringifyOutput — formatted', () => { json: () => ({}), } const out = formatted({ format: 'name', data: noName }) - expect(() => stringifyOutput(out)).toThrow('name output requires data.name()') + expect(() => stringifyOutput(out)).toThrow(OutputFormatNotSupportedError) }) it('unknown format: throws with allowed list', () => { const out = formatted({ format: 'csv', data: makeFormatted({}) }) - expect(() => stringifyOutput(out)).toThrow(/not supported/) - expect(() => stringifyOutput(out)).toThrow(/json, name, text, yaml/) + expect(() => stringifyOutput(out)).toThrow(OutputFormatNotSupportedError) }) }) @@ -175,13 +175,12 @@ describe('stringifyOutput — table', () => { json: () => [], } const out = table({ format: 'name', data: noName }) - expect(() => stringifyOutput(out)).toThrow('name output requires data.name()') + expect(() => stringifyOutput(out)).toThrow(OutputFormatNotSupportedError) }) it('unknown format: throws with allowed list', () => { const out = table({ format: 'csv', data: makeTable({}) }) - expect(() => stringifyOutput(out)).toThrow(/not supported/) - expect(() => stringifyOutput(out)).toThrow(/json, name, wide, yaml/) + expect(() => stringifyOutput(out)).toThrow(OutputFormatNotSupportedError) }) it('table renders column padding correctly', () => { diff --git a/cli/src/framework/output.ts b/cli/src/framework/output.ts index 7f56c1b1e2..fabf48caaf 100644 --- a/cli/src/framework/output.ts +++ b/cli/src/framework/output.ts @@ -1,4 +1,5 @@ import yaml from 'js-yaml' +import { OutputFormatNotSupportedError } from './errors' export type RawOutput = { readonly kind: 'raw' @@ -31,6 +32,14 @@ export type JsonPrintable = { readonly json: () => unknown } +export const OutputFormat = { + NAME: 'name', + JSON: 'json', + YAML: 'yaml', + TEXT: 'text', + WIDE: 'wide', +} as const + export type TableOutput = { readonly kind: 'table' readonly format: string @@ -77,32 +86,32 @@ export function stringifyOutput(output: CommandOutput): string { function stringifyFormattedOutput(output: FormattedOutput): string { switch (output.format) { case '': - case 'text': + case OutputFormat.TEXT: return output.data.text() - case 'json': + case OutputFormat.JSON: return `${JSON.stringify(output.data.json(), null, 2)}\n` - case 'yaml': + case OutputFormat.YAML: return yaml.dump(output.data.json(), { indent: 2, lineWidth: -1 }) - case 'name': + case OutputFormat.NAME: return `${toName(output.data)}\n` default: - throw new Error(`output format ${JSON.stringify(output.format)} not supported, allowed: json, name, text, yaml`) + throw new OutputFormatNotSupportedError(output.format) } } function stringifyTableOutput(output: TableOutput): string { switch (output.format) { case '': - case 'wide': + case OutputFormat.WIDE: return renderTable(output) - case 'json': + case OutputFormat.JSON: return `${JSON.stringify(output.data.json(), null, 2)}\n` - case 'yaml': + case OutputFormat.YAML: return yaml.dump(output.data.json(), { indent: 2, lineWidth: -1 }) - case 'name': + case OutputFormat.NAME: return `${toName(output.data)}\n` default: - throw new Error(`output format ${JSON.stringify(output.format)} not supported, allowed: json, name, wide, yaml`) + throw new OutputFormatNotSupportedError(output.format) } } @@ -186,7 +195,7 @@ function formatTable(rows: readonly (readonly string[])[]): string { function toName(data: TablePrintable | FormattedPrintable): string { if (!isNamePrintable(data)) - throw new Error('name output requires data.name()') + throw new OutputFormatNotSupportedError('name') return data.name() } diff --git a/cli/src/framework/types.ts b/cli/src/framework/types.ts index 1db95e9d16..62487c3622 100644 --- a/cli/src/framework/types.ts +++ b/cli/src/framework/types.ts @@ -9,7 +9,6 @@ export type FlagDefinition Scenario, state?: MockState): Hono { subject_type: 'external_sso', subject_email: 'sso@dify.ai', subject_issuer: 'https://issuer.example', + account: null, + workspaces: [], + default_workspace_id: null, token_id: 'tok-sso-1', }) } diff --git a/e2e/features/step-definitions/apps/publish-app.steps.ts b/e2e/features/step-definitions/apps/publish-app.steps.ts index de4f5ee63f..c426bc4c5a 100644 --- a/e2e/features/step-definitions/apps/publish-app.steps.ts +++ b/e2e/features/step-definitions/apps/publish-app.steps.ts @@ -7,9 +7,13 @@ When('I open the publish panel', async function (this: DifyWorld) { }) When('I publish the app', async function (this: DifyWorld) { - await this.getPage().getByRole('button', { name: /Publish Update/ }).click() + await this.getPage() + .getByRole('button', { name: /Publish Update/ }) + .click() }) Then('the app should be marked as published', async function (this: DifyWorld) { - await expect(this.getPage().getByRole('button', { name: 'Published' })).toBeVisible({ timeout: 30_000 }) + await expect(this.getPage().getByRole('button', { name: 'Published' })).toBeVisible({ + timeout: 30_000, + }) }) diff --git a/e2e/features/step-definitions/apps/share-app.steps.ts b/e2e/features/step-definitions/apps/share-app.steps.ts index 3ec038b065..c7acc91ebe 100644 --- a/e2e/features/step-definitions/apps/share-app.steps.ts +++ b/e2e/features/step-definitions/apps/share-app.steps.ts @@ -1,13 +1,21 @@ import type { DifyWorld } from '../../support/world' import { Given, Then, When } from '@cucumber/cucumber' import { expect } from '@playwright/test' -import { createTestApp, enableAppSiteAndGetURL, publishWorkflowApp, syncRunnableWorkflowDraft } from '../../../support/api' +import { + createTestApp, + enableAppSiteAndGetURL, + publishWorkflowApp, + syncRunnableWorkflowDraft, +} from '../../../support/api' When('I enable the Web App share', async function (this: DifyWorld) { const page = this.getPage() const appName = this.lastCreatedAppName - if (!appName) - throw new Error('No app name available. Run "a \\"workflow\\" app has been created via API" first.') + if (!appName) { + throw new Error( + 'No app name available. Run "a \\"workflow\\" app has been created via API" first.', + ) + } await page.locator('button').filter({ hasText: appName }).filter({ hasText: 'Workflow' }).click() await expect(page.getByRole('switch').first()).toBeEnabled({ timeout: 15_000 }) @@ -28,8 +36,11 @@ Given('a workflow app has been published and shared via API', async function (th }) When('I open the shared app URL', async function (this: DifyWorld) { - if (!this.shareURL) - throw new Error('No share URL available. Run "a workflow app has been published and shared via API" first.') + if (!this.shareURL) { + throw new Error( + 'No share URL available. Run "a workflow app has been published and shared via API" first.', + ) + } await this.getPage().goto(this.shareURL, { timeout: 20_000 }) }) diff --git a/e2e/features/step-definitions/apps/workflow-run.steps.ts b/e2e/features/step-definitions/apps/workflow-run.steps.ts index 84c03bfa8f..c225591d69 100644 --- a/e2e/features/step-definitions/apps/workflow-run.steps.ts +++ b/e2e/features/step-definitions/apps/workflow-run.steps.ts @@ -12,7 +12,7 @@ Given('a minimal runnable workflow draft has been synced', async function (this: When('I run the workflow', async function (this: DifyWorld) { const page = this.getPage() - const testRunButton = page.getByText('Test Run') + const testRunButton = page.getByRole('button', { name: /Test Run/ }) await expect(testRunButton).toBeVisible({ timeout: 15_000 }) await testRunButton.click() @@ -20,6 +20,6 @@ When('I run the workflow', async function (this: DifyWorld) { Then('the workflow run should succeed', async function (this: DifyWorld) { const page = this.getPage() - await page.getByText('DETAIL').click() - await expect(page.getByText('SUCCESS').first()).toBeVisible({ timeout: 55_000 }) + await page.getByText('DETAIL', { exact: true }).click() + await expect(page.getByText('SUCCESS', { exact: true }).first()).toBeVisible({ timeout: 55_000 }) }) diff --git a/packages/contracts/generated/api/service/orpc.gen.ts b/packages/contracts/generated/api/service/orpc.gen.ts index 43b2d4402b..2c046d3054 100644 --- a/packages/contracts/generated/api/service/orpc.gen.ts +++ b/packages/contracts/generated/api/service/orpc.gen.ts @@ -24,6 +24,7 @@ import { zGetAppFeedbacksResponse, zGetAppsAnnotationReplyByActionStatusByJobIdPath, zGetAppsAnnotationReplyByActionStatusByJobIdResponse, + zGetAppsAnnotationsQuery, zGetAppsAnnotationsResponse, zGetConversationsByCIdVariablesPath, zGetConversationsByCIdVariablesQuery, @@ -379,6 +380,7 @@ export const get4 = oc summary: 'List annotations for the application', tags: ['service_api'], }) + .input(z.object({ query: zGetAppsAnnotationsQuery.optional() })) .output(zGetAppsAnnotationsResponse) /** diff --git a/packages/contracts/generated/api/service/types.gen.ts b/packages/contracts/generated/api/service/types.gen.ts index cd84f94d81..4e187e7202 100644 --- a/packages/contracts/generated/api/service/types.gen.ts +++ b/packages/contracts/generated/api/service/types.gen.ts @@ -25,6 +25,12 @@ export type AnnotationList = { total: number } +export type AnnotationListQuery = { + keyword?: string + limit?: number + page?: number +} + export type AnnotationReplyActionPayload = { embedding_model_name: string embedding_provider_name: string @@ -969,7 +975,11 @@ export type GetAppsAnnotationReplyByActionStatusByJobIdResponse export type GetAppsAnnotationsData = { body?: never path?: never - query?: never + query?: { + keyword?: string + limit?: number + page?: number + } url: '/apps/annotations' } diff --git a/packages/contracts/generated/api/service/zod.gen.ts b/packages/contracts/generated/api/service/zod.gen.ts index e3008ddfbf..ae7c5cbf6c 100644 --- a/packages/contracts/generated/api/service/zod.gen.ts +++ b/packages/contracts/generated/api/service/zod.gen.ts @@ -32,6 +32,15 @@ export const zAnnotationList = z.object({ total: z.int(), }) +/** + * AnnotationListQuery + */ +export const zAnnotationListQuery = z.object({ + keyword: z.string().optional().default(''), + limit: z.int().gte(1).optional().default(20), + page: z.int().gte(1).optional().default(1), +}) + /** * AnnotationReplyActionPayload */ @@ -1216,6 +1225,12 @@ export const zGetAppsAnnotationReplyByActionStatusByJobIdResponse = z.record( z.unknown(), ) +export const zGetAppsAnnotationsQuery = z.object({ + keyword: z.string().optional().default(''), + limit: z.int().gte(1).optional().default(20), + page: z.int().gte(1).optional().default(1), +}) + /** * Annotations retrieved successfully */ diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3ea3c29f54..28cf3c242e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -513,6 +513,9 @@ catalogs: scheduler: specifier: 0.27.0 version: 0.27.0 + server-only: + specifier: 0.0.1 + version: 0.0.1 sharp: specifier: 0.34.5 version: 0.34.5 @@ -1248,6 +1251,9 @@ importers: scheduler: specifier: 'catalog:' version: 0.27.0 + server-only: + specifier: 'catalog:' + version: 0.0.1 sharp: specifier: 'catalog:' version: 0.34.5 @@ -8424,6 +8430,9 @@ packages: resolution: {integrity: sha512-OwrZRZAfhHww0WEnKHDY8OM0U/Qs8OTfIDWhUD4BLpNJUfXK4cGmjiagGze086m+mhI+V2nD0gfbHEnJjb9STA==} engines: {node: '>=10'} + server-only@0.0.1: + resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==} + sharp@0.34.5: resolution: {integrity: sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==} engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0} @@ -16656,6 +16665,8 @@ snapshots: seroval@1.5.1: {} + server-only@0.0.1: {} + sharp@0.34.5: dependencies: '@img/colour': 1.1.0 @@ -17774,6 +17785,7 @@ time: remark-breaks@4.0.0: '2023-09-22T16:45:41.061Z' remark-directive@4.0.0: '2025-02-27T15:15:20.630Z' scheduler@0.27.0: '2025-10-01T21:39:15.208Z' + server-only@0.0.1: '2022-09-03T01:07:26.139Z' sharp@0.34.5: '2025-11-06T14:19:40.989Z' shiki@4.1.0: '2026-05-19T07:51:34.358Z' socket.io-client@4.8.3: '2025-12-23T16:39:16.428Z' diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 666a172437..2957d678d9 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -214,6 +214,7 @@ catalog: remark-breaks: 4.0.0 remark-directive: 4.0.0 scheduler: 0.27.0 + server-only: 0.0.1 sharp: 0.34.5 shiki: 4.1.0 socket.io-client: 4.8.3 diff --git a/web/.env.example b/web/.env.example index 81fff4275d..2684667cd4 100644 --- a/web/.env.example +++ b/web/.env.example @@ -4,6 +4,9 @@ NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT NEXT_PUBLIC_EDITION=SELF_HOSTED # The base path for the application NEXT_PUBLIC_BASE_PATH= +# Server-only console API origin for server-side requests. +# Usually matches CONSOLE_API_URL from Docker deployment; local dev can rely on NEXT_PUBLIC_API_PREFIX fallback. +CONSOLE_API_URL=http://localhost:5001 # The base URL of console application, refers to the Console base URL of WEB service if console domain is # different from api or web app domain. # example: https://cloud.dify.ai/console/api diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index 4829adacf0..e3446d4867 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -45,6 +45,7 @@ vi.mock('@/next/navigation', () => ({ push: mockRouterPush, replace: mockRouterReplace, }), + usePathname: () => '/apps', useSearchParams: () => new URLSearchParams(), })) @@ -228,7 +229,7 @@ describe('App List Browsing Flow', () => { mockPages = [createPage([])] renderList() - expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument() + expect(screen.getByText('app.firstEmpty.title')).toBeInTheDocument() }) it('should transition from loading to content when data loads', () => { @@ -283,7 +284,7 @@ describe('App List Browsing Flow', () => { renderList() - expect(screen.getByText('app.createApp')).toBeInTheDocument() + expect(screen.getByText('app.newApp.startFromBlank')).toBeInTheDocument() }) it('should hide NewAppCard when user is not a workspace editor', () => { @@ -294,7 +295,7 @@ describe('App List Browsing Flow', () => { renderList() - expect(screen.queryByText('app.createApp')).not.toBeInTheDocument() + expect(screen.queryByText('app.newApp.startFromBlank')).not.toBeInTheDocument() }) }) @@ -340,16 +341,18 @@ describe('App List Browsing Flow', () => { // -- Tab navigation -- describe('Tab Navigation', () => { - it('should render all category tabs', () => { + it('should render all category options', async () => { mockPages = [createPage([createMockApp()])] renderList() - expect(screen.getByText('app.types.all')).toBeInTheDocument() - expect(screen.getByText('app.types.workflow')).toBeInTheDocument() - expect(screen.getByText('app.types.advanced')).toBeInTheDocument() - expect(screen.getByText('app.types.chatbot')).toBeInTheDocument() - expect(screen.getByText('app.types.agent')).toBeInTheDocument() - expect(screen.getByText('app.types.completion')).toBeInTheDocument() + fireEvent.click(screen.getByRole('combobox', { name: 'app.types.all' })) + + expect(await screen.findByRole('option', { name: 'app.types.all' })).toBeInTheDocument() + expect(await screen.findByRole('option', { name: 'app.types.workflow' })).toBeInTheDocument() + expect(await screen.findByRole('option', { name: 'app.types.advanced' })).toBeInTheDocument() + expect(await screen.findByRole('option', { name: 'app.types.chatbot' })).toBeInTheDocument() + expect(await screen.findByRole('option', { name: 'app.types.agent' })).toBeInTheDocument() + expect(await screen.findByRole('option', { name: 'app.types.completion' })).toBeInTheDocument() }) }) diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index a487f102dd..ba3ab166de 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -42,6 +42,7 @@ vi.mock('@/next/navigation', () => ({ push: mockRouterPush, replace: mockRouterReplace, }), + usePathname: () => '/apps', useSearchParams: () => new URLSearchParams(), })) diff --git a/web/__tests__/header/account-dropdown-flow.test.tsx b/web/__tests__/header/account-dropdown-flow.test.tsx index eb128924c0..fd651931b5 100644 --- a/web/__tests__/header/account-dropdown-flow.test.tsx +++ b/web/__tests__/header/account-dropdown-flow.test.tsx @@ -141,7 +141,6 @@ describe('Header Account Dropdown Flow', () => { }) it('logs out, resets cached user markers, and redirects to signin', async () => { - localStorage.setItem('setup_status', 'done') localStorage.setItem('education-reverify-prev-expire-at', '1') localStorage.setItem('education-reverify-has-noticed', '1') localStorage.setItem('education-expired-has-noticed', '1') @@ -157,7 +156,6 @@ describe('Header Account Dropdown Flow', () => { expect(mockPush).toHaveBeenCalledWith('/signin') }) - expect(localStorage.getItem('setup_status')).toBeNull() expect(localStorage.getItem('education-reverify-prev-expire-at')).toBeNull() expect(localStorage.getItem('education-reverify-has-noticed')).toBeNull() expect(localStorage.getItem('education-expired-has-noticed')).toBeNull() diff --git a/web/app/(commonLayout)/__tests__/hydration-boundary.spec.tsx b/web/app/(commonLayout)/__tests__/hydration-boundary.spec.tsx new file mode 100644 index 0000000000..20dea97243 --- /dev/null +++ b/web/app/(commonLayout)/__tests__/hydration-boundary.spec.tsx @@ -0,0 +1,124 @@ +import type { ReactElement } from 'react' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const mocks = vi.hoisted(() => ({ + queryClient: undefined as QueryClient | undefined, + profileQueryFn: vi.fn(), + systemFeaturesQueryFn: vi.fn(), + redirect: vi.fn((url: string) => { + throw new Error(`NEXT_REDIRECT:${url}`) + }), + headers: vi.fn(), + resolveServerConsoleApiUrl: vi.fn(), +})) + +vi.mock('@/context/query-client-server', () => ({ + getQueryClientServer: () => mocks.queryClient, +})) + +vi.mock('@/next/headers', () => ({ + headers: () => mocks.headers(), +})) + +vi.mock('@/next/navigation', () => ({ + redirect: (url: string) => mocks.redirect(url), +})) + +vi.mock('@/features/account-profile/server', () => ({ + resolveServerConsoleApiUrl: (...args: unknown[]) => mocks.resolveServerConsoleApiUrl(...args), + serverUserProfileQueryOptions: () => ({ + queryKey: ['common', 'user-profile'], + queryFn: mocks.profileQueryFn, + retry: false, + }), +})) + +vi.mock('@/service/system-features', () => ({ + systemFeaturesQueryOptions: () => ({ + queryKey: ['console', 'system-features'], + queryFn: mocks.systemFeaturesQueryFn, + retry: false, + }), +})) + +describe('CommonLayoutHydrationBoundary', () => { + beforeEach(() => { + vi.clearAllMocks() + mocks.queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } }) + mocks.headers.mockResolvedValue(new Headers({ + 'x-dify-pathname': '/apps', + 'x-dify-search': '?tag=workflow', + })) + mocks.resolveServerConsoleApiUrl.mockReturnValue('https://console.example.com/console/api/account/profile') + mocks.profileQueryFn.mockResolvedValue({ + profile: { + id: 'account-id', + name: 'Dify User', + email: 'user@example.com', + avatar: '', + avatar_url: null, + is_password_set: true, + }, + meta: { + currentVersion: '1.0.0', + currentEnv: 'DEVELOPMENT', + }, + }) + mocks.systemFeaturesQueryFn.mockResolvedValue({ branding: { enabled: false } }) + }) + + it('should hydrate common layout queries and render children', async () => { + const { CommonLayoutHydrationBoundary } = await import('../hydration-boundary') + + const element = await CommonLayoutHydrationBoundary({ + children:
Common shell
, + }) + + render( + + {element as ReactElement} + , + ) + expect(screen.getByText('Common shell')).toBeInTheDocument() + expect(mocks.profileQueryFn).toHaveBeenCalledTimes(1) + expect(mocks.systemFeaturesQueryFn).toHaveBeenCalledTimes(1) + }) + + it('should redirect unauthorized users to the refresh route with the current path', async () => { + mocks.profileQueryFn.mockRejectedValue(new Response(JSON.stringify({ code: 'unauthorized' }), { status: 401 })) + const { CommonLayoutHydrationBoundary } = await import('../hydration-boundary') + + await expect(CommonLayoutHydrationBoundary({ children: null })).rejects.toThrow('NEXT_REDIRECT') + + expect(mocks.redirect).toHaveBeenCalledWith('/auth/refresh?redirect_url=%2Fapps%3Ftag%3Dworkflow') + }) + + it('should redirect setup errors to install', async () => { + mocks.profileQueryFn.mockRejectedValue(new Response(JSON.stringify({ code: 'not_setup' }), { status: 401 })) + const { CommonLayoutHydrationBoundary } = await import('../hydration-boundary') + + await expect(CommonLayoutHydrationBoundary({ children: null })).rejects.toThrow('NEXT_REDIRECT') + + expect(mocks.redirect).toHaveBeenCalledWith('/install') + }) + + it('should render children without server prefetch when the server API URL is not resolvable', async () => { + mocks.resolveServerConsoleApiUrl.mockReturnValue(null) + const { CommonLayoutHydrationBoundary } = await import('../hydration-boundary') + + const element = await CommonLayoutHydrationBoundary({ + children:
Common shell
, + }) + + render( + + {element as ReactElement} + , + ) + expect(screen.getByText('Common shell')).toBeInTheDocument() + expect(mocks.profileQueryFn).not.toHaveBeenCalled() + expect(mocks.systemFeaturesQueryFn).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/(commonLayout)/education-apply/page.tsx b/web/app/(commonLayout)/education-apply/page.tsx index 82e47d5c0b..e5d5a9a5fa 100644 --- a/web/app/(commonLayout)/education-apply/page.tsx +++ b/web/app/(commonLayout)/education-apply/page.tsx @@ -1,8 +1,8 @@ 'use client' import { useEffect } from 'react' +import { FullScreenLoading } from '@/app/components/full-screen-loading' import EducationApplyPage from '@/app/education-apply/education-apply-page' -import RootLoading from '@/app/loading' import { useProviderContext } from '@/context/provider-context' import { useRouter, @@ -28,7 +28,7 @@ export default function EducationApply() { }, [enableEducationPlan, isFetchedPlanInfo, router, token]) if (!isFetchedPlanInfo || !enableEducationPlan || !token || isLoadingEducationAccountInfo) - return + return return } diff --git a/web/app/(commonLayout)/error.tsx b/web/app/(commonLayout)/error.tsx index dbc5ded3e9..1548ffd741 100644 --- a/web/app/(commonLayout)/error.tsx +++ b/web/app/(commonLayout)/error.tsx @@ -2,8 +2,8 @@ import { Button } from '@langgenius/dify-ui/button' import { useTranslation } from 'react-i18next' -import RootLoading from '@/app/loading' -import { isLegacyBase401 } from '@/service/use-common' +import { FullScreenLoading } from '@/app/components/full-screen-loading' +import { isLegacyBase401 } from '@/features/account-profile/client' type Props = { error: Error & { digest?: string } @@ -18,7 +18,7 @@ export default function CommonLayoutError({ error, unstable_retry }: Props) { // Showing the "Try again" button here would just flash for a few frames before // the page navigates away, and clicking it would 401 again anyway. if (isLegacyBase401(error)) - return + return return (
diff --git a/web/app/(commonLayout)/hydration-boundary.tsx b/web/app/(commonLayout)/hydration-boundary.tsx new file mode 100644 index 0000000000..fe4cf49420 --- /dev/null +++ b/web/app/(commonLayout)/hydration-boundary.tsx @@ -0,0 +1,86 @@ +import type { ReactNode } from 'react' +import { dehydrate, HydrationBoundary } from '@tanstack/react-query' +import { getQueryClientServer } from '@/context/query-client-server' +import { resolveServerConsoleApiUrl, serverUserProfileQueryOptions } from '@/features/account-profile/server' +import { headers } from '@/next/headers' +import { redirect } from '@/next/navigation' +import { systemFeaturesQueryOptions } from '@/service/system-features' +import { basePath } from '@/utils/var' + +const CURRENT_PATHNAME_HEADER = 'x-dify-pathname' +const CURRENT_SEARCH_HEADER = 'x-dify-search' +const ACCOUNT_PROFILE_PATH = '/account/profile' +const AUTH_REFRESH_PATH = '/auth/refresh' + +type ConsoleErrorPayload = { + code?: string +} + +const isConsoleErrorPayload = (value: unknown): value is ConsoleErrorPayload => + Boolean(value) && typeof value === 'object' && !Array.isArray(value) + +const parseConsoleErrorPayload = async (error: Response): Promise => { + try { + const payload: unknown = await error.clone().json() + return isConsoleErrorPayload(payload) ? payload : null + } + catch { + return null + } +} + +const getCurrentPath = async () => { + const requestHeaders = await headers() + const pathname = requestHeaders.get(CURRENT_PATHNAME_HEADER) || `${basePath}/apps` + const search = requestHeaders.get(CURRENT_SEARCH_HEADER) || '' + return `${pathname}${search}` +} + +const redirectToAuthRefresh = async () => { + const currentPath = await getCurrentPath() + redirect(`${basePath}${AUTH_REFRESH_PATH}?redirect_url=${encodeURIComponent(currentPath)}`) +} + +const handleProfileError = async (error: unknown) => { + if (!(error instanceof Response)) + throw error + + const errorData = await parseConsoleErrorPayload(error) + if (errorData?.code === 'not_setup') + redirect(`${basePath}/install`) + if (errorData?.code === 'not_init_validated') + redirect(`${basePath}/init`) + if (error.status === 401) + await redirectToAuthRefresh() + + throw error +} + +export async function CommonLayoutHydrationBoundary({ children }: { children: ReactNode }) { + const queryClient = getQueryClientServer() + const accountProfileUrl = resolveServerConsoleApiUrl(ACCOUNT_PROFILE_PATH) + + if (!accountProfileUrl) { + return ( + + {children} + + ) + } + + try { + await Promise.all([ + queryClient.fetchQuery(serverUserProfileQueryOptions()), + queryClient.prefetchQuery(systemFeaturesQueryOptions()), + ]) + } + catch (error) { + await handleProfileError(error) + } + + return ( + + {children} + + ) +} diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index 36e53dc808..bf74564cdf 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -1,26 +1,30 @@ import type { ReactNode } from 'react' import * as React from 'react' -import { AppInitializer } from '@/app/components/app-initializer' import InSiteMessageNotification from '@/app/components/app/in-site-message/notification' import AmplitudeProvider from '@/app/components/base/amplitude' import { GoogleAnalyticsScripts } from '@/app/components/base/ga' import Zendesk from '@/app/components/base/zendesk' +import { EducationVerifyActionRecorder } from '@/app/components/education-verify-action-recorder' import { GotoAnything } from '@/app/components/goto-anything' import MainNavLayout from '@/app/components/main-nav/layout' +import { OAuthRegistrationAnalytics } from '@/app/components/oauth-registration-analytics' import ReadmePanel from '@/app/components/plugins/readme-panel' import { AppContextProvider } from '@/context/app-context-provider' import { EventEmitterContextProvider } from '@/context/event-emitter-provider' import { ModalContextProvider } from '@/context/modal-context-provider' import { ProviderContextProvider } from '@/context/provider-context-provider' import PartnerStack from '../components/billing/partner-stack' +import { CommonLayoutHydrationBoundary } from './hydration-boundary' import RoleRouteGuard from './role-route-guard' -const Layout = ({ children }: { children: ReactNode }) => { +const Layout = async ({ children }: { children: ReactNode }) => { return ( <> - + + + @@ -38,8 +42,8 @@ const Layout = ({ children }: { children: ReactNode }) => { - - + + ) } diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index 185c80fc20..35ad4df3e8 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -216,7 +216,6 @@ const EmailChangeModal = ({ onClose, email }: Props) => { const handleLogout = async () => { await logout() - localStorage.removeItem('setup_status') // Tokens are now stored in cookies and cleared by backend router.push('/signin') diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index b0719e169a..8f12906e3a 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -16,10 +16,10 @@ import PremiumBadge from '@/app/components/base/premium-badge' import Collapse from '@/app/components/header/account-setting/collapse' import { IS_CE_EDITION, validPassword } from '@/config' import { useProviderContext } from '@/context/provider-context' +import { userProfileQueryOptions } from '@/features/account-profile/client' import { consoleQuery } from '@/service/client' import { updateUserProfile } from '@/service/common' import { systemFeaturesQueryOptions } from '@/service/system-features' -import { commonQueryKeys, userProfileQueryOptions } from '@/service/use-common' import DeleteAccount from '../delete-account' import AvatarWithEdit from './AvatarWithEdit' @@ -49,7 +49,7 @@ export default function AccountPage() { // Cache is warmed by AppContextProvider's useSuspenseQuery; this hits cache synchronously. const { data: userProfileResp } = useSuspenseQuery(userProfileQueryOptions()) const userProfile = userProfileResp.profile - const mutateUserProfile = () => queryClient.invalidateQueries({ queryKey: commonQueryKeys.userProfile }) + const mutateUserProfile = () => queryClient.invalidateQueries({ queryKey: userProfileQueryOptions().queryKey }) const { isEducationAccount } = useProviderContext() const [editNameModalVisible, setEditNameModalVisible] = useState(false) const [editName, setEditName] = useState('') diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index 197a00f822..63e443f89d 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -12,8 +12,9 @@ import { useTranslation } from 'react-i18next' import { resetUser } from '@/app/components/base/amplitude/utils' import PremiumBadge from '@/app/components/base/premium-badge' import { useProviderContext } from '@/context/provider-context' +import { userProfileQueryOptions } from '@/features/account-profile/client' import { useRouter } from '@/next/navigation' -import { useLogout, userProfileQueryOptions } from '@/service/use-common' +import { useLogout } from '@/service/use-common' export default function AppSelector() { const router = useRouter() @@ -31,7 +32,6 @@ export default function AppSelector() { const handleLogout = async () => { await logout() - localStorage.removeItem('setup_status') resetUser() // Tokens are now stored in cookies and cleared by backend diff --git a/web/app/account/(commonLayout)/layout.tsx b/web/app/account/(commonLayout)/layout.tsx index 4d344c3f78..a97588d203 100644 --- a/web/app/account/(commonLayout)/layout.tsx +++ b/web/app/account/(commonLayout)/layout.tsx @@ -1,21 +1,25 @@ import type { ReactNode } from 'react' import * as React from 'react' -import { AppInitializer } from '@/app/components/app-initializer' +import { CommonLayoutHydrationBoundary } from '@/app/(commonLayout)/hydration-boundary' import AmplitudeProvider from '@/app/components/base/amplitude' import { GoogleAnalyticsScripts } from '@/app/components/base/ga' +import { EducationVerifyActionRecorder } from '@/app/components/education-verify-action-recorder' import HeaderWrapper from '@/app/components/header/header-wrapper' +import { OAuthRegistrationAnalytics } from '@/app/components/oauth-registration-analytics' import { AppContextProvider } from '@/context/app-context-provider' import { EventEmitterContextProvider } from '@/context/event-emitter-provider' import { ModalContextProvider } from '@/context/modal-context-provider' import { ProviderContextProvider } from '@/context/provider-context-provider' import Header from './header' -const Layout = ({ children }: { children: ReactNode }) => { +const Layout = async ({ children }: { children: ReactNode }) => { return ( <> - + + + @@ -30,7 +34,7 @@ const Layout = ({ children }: { children: ReactNode }) => { - + ) } diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx index 850fe9c2b5..af4b13dabb 100644 --- a/web/app/account/oauth/authorize/layout.tsx +++ b/web/app/account/oauth/authorize/layout.tsx @@ -5,9 +5,9 @@ import { useQuery, useSuspenseQuery } from '@tanstack/react-query' import Loading from '@/app/components/base/loading' import Header from '@/app/signin/_header' import { AppContextProvider } from '@/context/app-context-provider' +import { isLegacyBase401, userProfileQueryOptions } from '@/features/account-profile/client' import useDocumentTitle from '@/hooks/use-document-title' import { systemFeaturesQueryOptions } from '@/service/system-features' -import { isLegacyBase401, userProfileQueryOptions } from '@/service/use-common' export default function SignInLayout({ children }: any) { const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index 97dabb46a9..d461794f8b 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -16,9 +16,9 @@ import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' -import { setOAuthPendingRedirect } from '@/app/signin/utils/post-login-redirect' +import { isLegacyBase401, userProfileQueryOptions } from '@/features/account-profile/client' import { useRouter, useSearchParams } from '@/next/navigation' -import { isLegacyBase401, useLogout, userProfileQueryOptions } from '@/service/use-common' +import { useLogout } from '@/service/use-common' import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' function buildReturnUrl(pathname: string, search: string) { @@ -80,7 +80,6 @@ export default function OAuthAuthorize() { const onLoginSwitchClick = async () => { try { const returnUrl = buildReturnUrl('/account/oauth/authorize', `?${searchParams.toString()}`) - setOAuthPendingRedirect(returnUrl) if (isLoggedIn) await logout() router.push(`/signin?redirect_url=${encodeURIComponent(returnUrl)}`) diff --git a/web/app/auth/refresh/__tests__/route.spec.ts b/web/app/auth/refresh/__tests__/route.spec.ts new file mode 100644 index 0000000000..c86b6261a8 --- /dev/null +++ b/web/app/auth/refresh/__tests__/route.spec.ts @@ -0,0 +1,104 @@ +// @vitest-environment node + +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@/config', () => ({ + API_PREFIX: 'http://localhost:5001/console/api', +})) + +vi.mock('@/config/server', () => ({ + SERVER_CONSOLE_API_PREFIX: undefined, +})) + +vi.mock('@/utils/var', () => ({ + basePath: '', +})) + +const getSetCookieHeaders = (headers: Headers) => { + const getSetCookie = Reflect.get(headers, 'getSetCookie') + + if (typeof getSetCookie === 'function') { + const values: unknown = getSetCookie.call(headers) + return Array.isArray(values) ? values : [] + } + + const setCookie = headers.get('set-cookie') + return setCookie ? [setCookie] : [] +} + +const createRequest = (url: string, cookie?: string) => ({ + url, + headers: new Headers(cookie ? { cookie } : undefined), +}) as Request + +describe('auth refresh route', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.unstubAllGlobals() + }) + + it('should refresh cookies and redirect back to the requested path', async () => { + const headers = new Headers() + Object.defineProperty(headers, 'getSetCookie', { + value: () => [ + 'access_token=new-access; Path=/; HttpOnly', + 'refresh_token=new-refresh; Path=/; HttpOnly', + ], + }) + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + headers, + } as Response) + vi.stubGlobal('fetch', fetchMock) + const { GET } = await import('../route') + + const response = await GET(createRequest( + 'http://localhost:3000/auth/refresh?redirect_url=%2Fapps%3Fcategory%3Dworkflow', + 'refresh_token=old-refresh', + )) + + expect(fetchMock).toHaveBeenCalledWith( + 'http://localhost:5001/console/api/refresh-token', + expect.objectContaining({ + method: 'POST', + cache: 'no-store', + headers: expect.any(Headers), + }), + ) + const fetchHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Headers + expect(fetchHeaders.get('cookie')).toBe('refresh_token=old-refresh') + expect(response.status).toBe(303) + expect(response.headers.get('location')).toBe('http://localhost:3000/apps?category=workflow') + expect(getSetCookieHeaders(response.headers)).toEqual([ + 'access_token=new-access; Path=/; HttpOnly', + 'refresh_token=new-refresh; Path=/; HttpOnly', + ]) + }) + + it('should redirect to signin when refresh token is rejected', async () => { + vi.stubGlobal('fetch', vi.fn().mockResolvedValue(new Response(null, { status: 401 }))) + const { GET } = await import('../route') + + const response = await GET(createRequest( + 'http://localhost:3000/auth/refresh?redirect_url=%2Fapps', + 'refresh_token=expired', + )) + + expect(response.status).toBe(303) + expect(response.headers.get('location')).toBe('http://localhost:3000/signin?redirect_url=%2Fapps') + }) + + it('should ignore cross-origin redirect targets', async () => { + const fetchMock = vi.fn().mockResolvedValue(new Response(null, { status: 401 })) + vi.stubGlobal('fetch', fetchMock) + const { GET } = await import('../route') + + const response = await GET(createRequest( + 'http://localhost:3000/auth/refresh?redirect_url=https%3A%2F%2Fevil.example', + 'refresh_token=expired', + )) + + expect(response.status).toBe(303) + expect(response.headers.get('location')).toBe('http://localhost:3000/signin?redirect_url=%2Fapps') + }) +}) diff --git a/web/app/auth/refresh/route.ts b/web/app/auth/refresh/route.ts new file mode 100644 index 0000000000..998f5a5ffe --- /dev/null +++ b/web/app/auth/refresh/route.ts @@ -0,0 +1,113 @@ +import { API_PREFIX } from '@/config' +import { SERVER_CONSOLE_API_PREFIX } from '@/config/server' +import { basePath } from '@/utils/var' + +const REFRESH_TOKEN_PATH = '/refresh-token' +const AUTH_REFRESH_PATH = `${basePath}/auth/refresh` +const DEFAULT_REDIRECT_PATH = `${basePath}/apps` + +const withTrailingSlash = (value: string) => value.endsWith('/') ? value : `${value}/` +const withoutLeadingSlash = (value: string) => value.startsWith('/') ? value.slice(1) : value + +const resolveAbsoluteUrlPrefix = (value: string) => { + try { + return new URL(value).toString() + } + catch { + return null + } +} + +const resolveServerConsoleApiUrl = (pathname: string, requestUrl: URL) => { + const requestPath = withoutLeadingSlash(pathname) + const apiPrefix = SERVER_CONSOLE_API_PREFIX + || resolveAbsoluteUrlPrefix(API_PREFIX) + || new URL(API_PREFIX, requestUrl.origin).toString() + + if (!apiPrefix) + return null + + return new URL(requestPath, withTrailingSlash(apiPrefix)).toString() +} + +const resolveSafeRedirectPath = (request: Request) => { + const requestUrl = new URL(request.url) + const redirectUrl = requestUrl.searchParams.get('redirect_url') + + if (!redirectUrl) + return DEFAULT_REDIRECT_PATH + + try { + const target = new URL(redirectUrl, requestUrl.origin) + if (target.origin !== requestUrl.origin) + return DEFAULT_REDIRECT_PATH + if (target.pathname === AUTH_REFRESH_PATH) + return DEFAULT_REDIRECT_PATH + + return `${target.pathname}${target.search}` + } + catch { + return DEFAULT_REDIRECT_PATH + } +} + +const getSetCookieHeaders = (headers: Headers) => { + const getSetCookie = Reflect.get(headers, 'getSetCookie') + + if (typeof getSetCookie === 'function') { + const values: unknown = getSetCookie.call(headers) + return Array.isArray(values) + ? values.filter((value): value is string => typeof value === 'string') + : [] + } + + const setCookie = headers.get('set-cookie') + return setCookie ? [setCookie] : [] +} + +const createRedirectResponse = (request: Request, pathname: string, setCookies: string[] = []) => { + const headers = new Headers({ + 'Cache-Control': 'no-store', + 'Location': new URL(pathname, request.url).toString(), + }) + + for (const cookie of setCookies) + headers.append('Set-Cookie', cookie) + + return new Response(null, { + status: 303, + headers, + }) +} + +const createSigninRedirectResponse = (request: Request, redirectPath: string) => + createRedirectResponse(request, `${basePath}/signin?redirect_url=${encodeURIComponent(redirectPath)}`) + +export async function GET(request: Request) { + const requestUrl = new URL(request.url) + const redirectPath = resolveSafeRedirectPath(request) + const refreshUrl = resolveServerConsoleApiUrl(REFRESH_TOKEN_PATH, requestUrl) + const cookie = request.headers.get('cookie') + + if (!refreshUrl || !cookie) + return createSigninRedirectResponse(request, redirectPath) + + try { + const response = await fetch(refreshUrl, { + method: 'POST', + headers: new Headers({ + 'Content-Type': 'application/json', + cookie, + }), + cache: 'no-store', + }) + + if (!response.ok) + return createSigninRedirectResponse(request, redirectPath) + + return createRedirectResponse(request, redirectPath, getSetCookieHeaders(response.headers)) + } + catch { + return createSigninRedirectResponse(request, redirectPath) + } +} diff --git a/web/app/components/__tests__/app-initializer.spec.tsx b/web/app/components/__tests__/app-initializer.spec.tsx deleted file mode 100644 index b4c2d08f2e..0000000000 --- a/web/app/components/__tests__/app-initializer.spec.tsx +++ /dev/null @@ -1,197 +0,0 @@ -import { screen, waitFor } from '@testing-library/react' -import Cookies from 'js-cookie' -import { beforeEach, describe, expect, it, vi } from 'vitest' -import { - EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, - EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, -} from '@/app/education-apply/constants' -import { resolvePostLoginRedirect } from '@/app/signin/utils/post-login-redirect' -import { usePathname, useRouter, useSearchParams } from '@/next/navigation' -import { renderWithNuqs } from '@/test/nuqs-testing' -import { fetchSetupStatusWithCache } from '@/utils/setup-status' -import { AppInitializer } from '../app-initializer' - -const { mockSendGAEvent, mockTrackEvent } = vi.hoisted(() => ({ - mockSendGAEvent: vi.fn(), - mockTrackEvent: vi.fn(), -})) - -vi.mock('@/next/navigation', () => ({ - usePathname: vi.fn(), - useRouter: vi.fn(), - useSearchParams: vi.fn(), -})) - -vi.mock('@/utils/setup-status', () => ({ - fetchSetupStatusWithCache: vi.fn(), -})) - -vi.mock('@/app/signin/utils/post-login-redirect', () => ({ - resolvePostLoginRedirect: vi.fn(), -})) - -vi.mock('@/utils/gtag', () => ({ - sendGAEvent: (...args: unknown[]) => mockSendGAEvent(...args), -})) - -vi.mock('../base/amplitude', () => ({ - trackEvent: (...args: unknown[]) => mockTrackEvent(...args), -})) - -const mockUsePathname = vi.mocked(usePathname) -const mockUseRouter = vi.mocked(useRouter) -const mockUseSearchParams = vi.mocked(useSearchParams) -const mockFetchSetupStatusWithCache = vi.mocked(fetchSetupStatusWithCache) -const mockResolvePostLoginRedirect = vi.mocked(resolvePostLoginRedirect) -const mockReplace = vi.fn() - -describe('AppInitializer', () => { - beforeEach(() => { - vi.clearAllMocks() - vi.unstubAllGlobals() - window.localStorage.clear() - window.sessionStorage.clear() - Cookies.remove('utm_info') - vi.spyOn(console, 'error').mockImplementation(() => {}) - mockUsePathname.mockReturnValue('/apps') - mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType) - mockUseSearchParams.mockReturnValue(new URLSearchParams() as unknown as ReturnType) - mockFetchSetupStatusWithCache.mockResolvedValue({ step: 'finished' }) - mockResolvePostLoginRedirect.mockReturnValue(null) - }) - - it('renders children after setup checks finish', async () => { - renderWithNuqs( - -
ready
-
, - ) - - await waitFor(() => expect(screen.getByText('ready')).toBeInTheDocument()) - - expect(mockFetchSetupStatusWithCache).toHaveBeenCalledTimes(1) - expect(mockReplace).not.toHaveBeenCalledWith('/signin') - }) - - it('redirects to install when setup status loading fails', async () => { - mockFetchSetupStatusWithCache.mockRejectedValue(new Error('unauthorized')) - - renderWithNuqs( - -
ready
-
, - ) - - await waitFor(() => expect(mockReplace).toHaveBeenCalledWith('/install')) - expect(screen.queryByText('ready')).not.toBeInTheDocument() - }) - - it('does not persist create app attribution from the url anymore', async () => { - renderWithNuqs( - -
ready
-
, - ) - - await waitFor(() => expect(screen.getByText('ready')).toBeInTheDocument()) - - expect(window.sessionStorage.getItem('create_app_external_attribution')).toBeNull() - }) - - it('tracks oauth registration with utm info and clears the cookie', async () => { - Cookies.set('utm_info', JSON.stringify({ - utm_source: 'linkedin', - slug: 'agent-launch', - })) - - renderWithNuqs( - -
ready
-
, - { searchParams: 'oauth_new_user=true' }, - ) - - await waitFor(() => expect(screen.getByText('ready')).toBeInTheDocument()) - - expect(mockTrackEvent).toHaveBeenCalledWith('user_registration_success_with_utm', { - method: 'oauth', - utm_source: 'linkedin', - slug: 'agent-launch', - }) - expect(mockSendGAEvent).toHaveBeenCalledWith('user_registration_success_with_utm', { - method: 'oauth', - utm_source: 'linkedin', - slug: 'agent-launch', - }) - expect(mockReplace).toHaveBeenCalledWith('/apps') - expect(Cookies.get('utm_info')).toBeUndefined() - }) - - it('falls back to the base registration event when the oauth utm cookie is invalid', async () => { - Cookies.set('utm_info', '{invalid-json') - - renderWithNuqs( - -
ready
-
, - { searchParams: 'oauth_new_user=true' }, - ) - - await waitFor(() => expect(screen.getByText('ready')).toBeInTheDocument()) - - expect(mockTrackEvent).toHaveBeenCalledWith('user_registration_success', { - method: 'oauth', - }) - expect(mockSendGAEvent).toHaveBeenCalledWith('user_registration_success', { - method: 'oauth', - }) - expect(console.error).toHaveBeenCalled() - expect(Cookies.get('utm_info')).toBeUndefined() - }) - - it('stores the education verification flag in localStorage', async () => { - mockUseSearchParams.mockReturnValue( - new URLSearchParams(`action=${EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION}`) as unknown as ReturnType, - ) - - renderWithNuqs( - -
ready
-
, - ) - - await waitFor(() => expect(screen.getByText('ready')).toBeInTheDocument()) - - expect(window.localStorage.getItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)).toBe('yes') - }) - - it('redirects to the resolved post-login url when one exists', async () => { - const mockLocationReplace = vi.fn() - vi.stubGlobal('location', { ...window.location, replace: mockLocationReplace }) - mockResolvePostLoginRedirect.mockReturnValue('/explore') - - renderWithNuqs( - -
ready
-
, - ) - - await waitFor(() => expect(mockLocationReplace).toHaveBeenCalledWith('/explore')) - expect(screen.queryByText('ready')).not.toBeInTheDocument() - }) - - it('redirects to signin when redirect resolution throws', async () => { - mockResolvePostLoginRedirect.mockImplementation(() => { - throw new Error('redirect resolution failed') - }) - - renderWithNuqs( - -
ready
-
, - ) - - await waitFor(() => expect(mockReplace).toHaveBeenCalledWith('/signin')) - expect(screen.queryByText('ready')).not.toBeInTheDocument() - }) -}) diff --git a/web/app/components/__tests__/education-verify-action-recorder.spec.tsx b/web/app/components/__tests__/education-verify-action-recorder.spec.tsx new file mode 100644 index 0000000000..416715abf2 --- /dev/null +++ b/web/app/components/__tests__/education-verify-action-recorder.spec.tsx @@ -0,0 +1,40 @@ +import { render, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { + EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, + EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, +} from '@/app/education-apply/constants' +import { useSearchParams } from '@/next/navigation' +import { EducationVerifyActionRecorder } from '../education-verify-action-recorder' + +vi.mock('@/next/navigation', () => ({ + useSearchParams: vi.fn(), +})) + +const mockUseSearchParams = vi.mocked(useSearchParams) + +describe('EducationVerifyActionRecorder', () => { + beforeEach(() => { + vi.clearAllMocks() + window.localStorage.clear() + mockUseSearchParams.mockReturnValue(new URLSearchParams() as unknown as ReturnType) + }) + + it('should store the education verification flag when the callback action is present', async () => { + mockUseSearchParams.mockReturnValue( + new URLSearchParams(`action=${EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION}`) as unknown as ReturnType, + ) + + render() + + await waitFor(() => { + expect(window.localStorage.getItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)).toBe('yes') + }) + }) + + it('should leave localStorage unchanged for unrelated routes', () => { + render() + + expect(window.localStorage.getItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)).toBeNull() + }) +}) diff --git a/web/app/components/__tests__/oauth-registration-analytics.spec.tsx b/web/app/components/__tests__/oauth-registration-analytics.spec.tsx new file mode 100644 index 0000000000..6bc9fe4fe2 --- /dev/null +++ b/web/app/components/__tests__/oauth-registration-analytics.spec.tsx @@ -0,0 +1,106 @@ +import { render, waitFor } from '@testing-library/react' +import Cookies from 'js-cookie' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useSearchParams } from '@/next/navigation' +import { OAuthRegistrationAnalytics } from '../oauth-registration-analytics' + +const { mockSendGAEvent, mockTrackEvent } = vi.hoisted(() => ({ + mockSendGAEvent: vi.fn(), + mockTrackEvent: vi.fn(), +})) + +vi.mock('@/utils/gtag', () => ({ + sendGAEvent: (...args: unknown[]) => mockSendGAEvent(...args), +})) + +vi.mock('@/next/navigation', () => ({ + useSearchParams: vi.fn(), +})) + +vi.mock('../base/amplitude', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) + +const mockUseSearchParams = vi.mocked(useSearchParams) + +const setSearchParams = (searchParams = '') => { + mockUseSearchParams.mockReturnValue(new URLSearchParams(searchParams) as unknown as ReturnType) + window.history.replaceState(null, '', `/signin${searchParams ? `?${searchParams}` : ''}`) +} + +describe('OAuthRegistrationAnalytics', () => { + beforeEach(() => { + vi.clearAllMocks() + Cookies.remove('utm_info') + vi.spyOn(console, 'error').mockImplementation(() => {}) + setSearchParams() + }) + + it('should track oauth registration with utm info and clear the query flag', async () => { + Cookies.set('utm_info', JSON.stringify({ + utm_source: 'linkedin', + slug: 'agent-launch', + })) + + setSearchParams('oauth_new_user=true&source=signin') + const replaceStateSpy = vi.spyOn(window.history, 'replaceState') + + render() + + await waitFor(() => { + expect(mockTrackEvent).toHaveBeenCalledWith('user_registration_success_with_utm', { + method: 'oauth', + utm_source: 'linkedin', + slug: 'agent-launch', + }) + }) + expect(mockSendGAEvent).toHaveBeenCalledWith('user_registration_success_with_utm', { + method: 'oauth', + utm_source: 'linkedin', + slug: 'agent-launch', + }) + expect(Cookies.get('utm_info')).toBeUndefined() + + await waitFor(() => { + expect(replaceStateSpy).toHaveBeenCalledWith(null, '', '/signin?source=signin') + }) + }) + + it('should fall back to the base registration event when the utm cookie is invalid', async () => { + Cookies.set('utm_info', '{invalid-json') + + setSearchParams('oauth_new_user=true') + render() + + await waitFor(() => { + expect(mockTrackEvent).toHaveBeenCalledWith('user_registration_success', { + method: 'oauth', + }) + }) + expect(mockSendGAEvent).toHaveBeenCalledWith('user_registration_success', { + method: 'oauth', + }) + expect(console.error).toHaveBeenCalled() + expect(Cookies.get('utm_info')).toBeUndefined() + }) + + it('should do nothing without the oauth registration query flag', () => { + render() + + expect(mockTrackEvent).not.toHaveBeenCalled() + expect(mockSendGAEvent).not.toHaveBeenCalled() + }) + + it('should clear a false oauth registration query flag without tracking', async () => { + setSearchParams('oauth_new_user=false') + const replaceStateSpy = vi.spyOn(window.history, 'replaceState') + + render() + + await waitFor(() => { + expect(replaceStateSpy).toHaveBeenCalledWith(null, '', '/signin') + }) + expect(mockTrackEvent).not.toHaveBeenCalled() + expect(mockSendGAEvent).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx deleted file mode 100644 index 3d2af1ce61..0000000000 --- a/web/app/components/app-initializer.tsx +++ /dev/null @@ -1,103 +0,0 @@ -'use client' - -import type { ReactNode } from 'react' -import Cookies from 'js-cookie' -import { parseAsBoolean, useQueryState } from 'nuqs' -import { useCallback, useEffect, useState } from 'react' -import { - EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, - EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, -} from '@/app/education-apply/constants' -import RootLoading from '@/app/loading' -import { usePathname, useRouter, useSearchParams } from '@/next/navigation' -import { sendGAEvent } from '@/utils/gtag' -import { fetchSetupStatusWithCache } from '@/utils/setup-status' -import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' -import { trackEvent } from './base/amplitude' - -type AppInitializerProps = { - children: ReactNode -} - -export const AppInitializer = ({ - children, -}: AppInitializerProps) => { - const router = useRouter() - const searchParams = useSearchParams() - // Tokens are now stored in cookies, no need to check localStorage - const pathname = usePathname() - const [init, setInit] = useState(false) - const [oauthNewUser] = useQueryState( - 'oauth_new_user', - parseAsBoolean.withOptions({ history: 'replace' }), - ) - const isSetupFinished = useCallback(async () => { - try { - const setUpStatus = await fetchSetupStatusWithCache() - return setUpStatus.step === 'finished' - } - catch (error) { - console.error(error) - return false - } - }, []) - - useEffect(() => { - (async () => { - const action = searchParams.get('action') - - if (oauthNewUser) { - let utmInfo = null - const utmInfoStr = Cookies.get('utm_info') - if (utmInfoStr) { - try { - utmInfo = JSON.parse(utmInfoStr) - } - catch (e) { - console.error('Failed to parse utm_info cookie:', e) - } - } - - // Track registration event with UTM params - trackEvent(utmInfo ? 'user_registration_success_with_utm' : 'user_registration_success', { - method: 'oauth', - ...utmInfo, - }) - - sendGAEvent(utmInfo ? 'user_registration_success_with_utm' : 'user_registration_success', { - method: 'oauth', - ...utmInfo, - }) - - Cookies.remove('utm_info') - } - - if (oauthNewUser !== null) - router.replace(pathname) - - if (action === EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION) - localStorage.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes') - - try { - const isFinished = await isSetupFinished() - if (!isFinished) { - router.replace('/install') - return - } - - const redirectUrl = resolvePostLoginRedirect(searchParams) - if (redirectUrl) { - location.replace(redirectUrl) - return - } - - setInit(true) - } - catch { - router.replace('/signin') - } - })() - }, [isSetupFinished, router, pathname, searchParams, oauthNewUser]) - - return init ? children : -} diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index 5e1d37546a..87ac870e7e 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -1,4 +1,4 @@ -import { act, fireEvent, screen } from '@testing-library/react' +import { act, fireEvent, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' import { createSystemFeaturesWrapper } from '@/__tests__/utils/mock-system-features' @@ -14,9 +14,11 @@ const mockUseWorkflowOnlineUsers = vi.hoisted(() => vi.fn((_options: unknown) => const mockReplace = vi.fn() const mockRouter = { replace: mockReplace } +let mockSearchParams = new URLSearchParams('') vi.mock('@/next/navigation', () => ({ useRouter: () => mockRouter, - useSearchParams: () => new URLSearchParams(''), + usePathname: () => '/apps', + useSearchParams: () => mockSearchParams, })) vi.mock('@/service/client', () => ({ @@ -57,12 +59,10 @@ vi.mock('@/context/provider-context', () => ({ })) const mockSetKeywords = vi.fn() -const mockSetTagIDs = vi.fn() const mockSetIsCreatedByMe = vi.fn() const mockSetCategory = vi.fn() const mockQueryState = { category: 'all', - tagIDs: [] as string[], keywords: '', isCreatedByMe: false, emptyAppList: false, @@ -73,11 +73,20 @@ vi.mock('../hooks/use-apps-query-state', () => ({ query: mockQueryState, setCategory: mockSetCategory, setKeywords: mockSetKeywords, - setTagIDs: mockSetTagIDs, setIsCreatedByMe: mockSetIsCreatedByMe, }), })) +vi.mock('@/features/tag-management/components/tag-filter', () => ({ + TagFilter: ({ value, onChange, onOpenTagManagement }: { value: string[], onChange: (value: string[]) => void, onOpenTagManagement: () => void }) => ( +
+ + {value.join(',')} + +
+ ), +})) + let mockOnDSLFileDropped: ((file: File) => void) | null = null let mockDragging = false vi.mock('../hooks/use-dsl-drag-drop', () => ({ @@ -258,6 +267,7 @@ beforeAll(() => { // Render helper wrapping with shared nuqs testing helper plus a seeded // systemFeatures cache so List can resolve its useSuspenseQuery. const renderList = (searchParams = '') => { + mockSearchParams = new URLSearchParams(searchParams) const { wrapper: SystemFeaturesWrapper } = createSystemFeaturesWrapper({ systemFeatures: { branding: { enabled: false } }, }) @@ -286,7 +296,6 @@ describe('List', () => { mockServiceState.isLoading = false mockServiceState.isFetchingNextPage = false mockQueryState.category = 'all' - mockQueryState.tagIDs = [] mockQueryState.keywords = '' mockQueryState.isCreatedByMe = false mockQueryState.emptyAppList = false @@ -489,12 +498,12 @@ describe('List', () => { describe('App List Query', () => { it('should build paged query input from active filters', () => { - mockQueryState.tagIDs = ['tag-1'] mockQueryState.keywords = 'sales' mockQueryState.isCreatedByMe = true mockQueryState.category = AppModeEnum.WORKFLOW renderList() + fireEvent.click(screen.getByText('common.tag.placeholder')) const options = mockAppListInfiniteOptions.mock.calls.at(-1)?.[0] as AppListInfiniteOptions @@ -511,6 +520,17 @@ describe('List', () => { expect(options.getNextPageParam({ has_more: true, page: 2 })).toBe(3) expect(options.getNextPageParam({ has_more: false, page: 2 })).toBeUndefined() }) + + it('should remove legacy tagIDs from URL while preserving other filters', async () => { + renderList('?category=workflow&tagIDs=tag-1;tag-2&keywords=sales&isCreatedByMe=true') + + await waitFor(() => { + expect(mockReplace).toHaveBeenCalledWith( + '/apps?category=workflow&keywords=sales&isCreatedByMe=true', + { scroll: false }, + ) + }) + }) }) describe('Tag Filter', () => { diff --git a/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx b/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx index 7adb658c36..e6aa20d38d 100644 --- a/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx +++ b/web/app/components/apps/hooks/__tests__/use-apps-query-state.spec.tsx @@ -5,6 +5,7 @@ import { APP_LIST_SEARCH_DEBOUNCE_MS } from '../../constants' import { useAppsQueryState } from '../use-apps-query-state' const renderWithAdapter = (searchParams = '') => { + // eslint-disable-next-line react/use-state -- renderHook executes a custom hook, not React.useState return renderHookWithNuqs(() => useAppsQueryState(), { searchParams }) } @@ -18,14 +19,12 @@ describe('useAppsQueryState', () => { expect(result.current.query).toEqual({ category: 'all', - tagIDs: [], keywords: '', isCreatedByMe: false, emptyAppList: false, }) expect(typeof result.current.setCategory).toBe('function') expect(typeof result.current.setKeywords).toBe('function') - expect(typeof result.current.setTagIDs).toBe('function') expect(typeof result.current.setIsCreatedByMe).toBe('function') }) @@ -36,7 +35,6 @@ describe('useAppsQueryState', () => { expect(result.current.query).toEqual({ category: AppModeEnum.WORKFLOW, - tagIDs: ['tag1', 'tag2'], keywords: 'search term', isCreatedByMe: true, emptyAppList: true, @@ -119,33 +117,6 @@ describe('useAppsQueryState', () => { } }) - it('should update tag filter URL state', async () => { - const { result, onUrlUpdate } = renderWithAdapter() - - act(() => { - result.current.setTagIDs(['tag1', 'tag2']) - }) - - await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const update = onUrlUpdate.mock.calls.at(-1)![0] - expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2']) - expect(update.searchParams.get('tagIDs')).toBe('tag1;tag2') - expect(update.options.history).toBe('push') - }) - - it('should remove tagIDs from URL when empty', async () => { - const { result, onUrlUpdate } = renderWithAdapter('?tagIDs=tag1;tag2') - - act(() => { - result.current.setTagIDs([]) - }) - - await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) - const update = onUrlUpdate.mock.calls.at(-1)![0] - expect(result.current.query.tagIDs).toEqual([]) - expect(update.searchParams.has('tagIDs')).toBe(false) - }) - it('should update created-by-me URL state', async () => { const { result, onUrlUpdate } = renderWithAdapter() diff --git a/web/app/components/apps/hooks/use-apps-query-state.ts b/web/app/components/apps/hooks/use-apps-query-state.ts index 0d15b8b030..bb92538b7a 100644 --- a/web/app/components/apps/hooks/use-apps-query-state.ts +++ b/web/app/components/apps/hooks/use-apps-query-state.ts @@ -1,4 +1,4 @@ -import { debounce, parseAsArrayOf, parseAsBoolean, parseAsString, parseAsStringLiteral, useQueryStates } from 'nuqs' +import { debounce, parseAsBoolean, parseAsString, parseAsStringLiteral, useQueryStates } from 'nuqs' import { useCallback, useMemo } from 'react' import { AppModes } from '@/types/app' import { APP_LIST_SEARCH_DEBOUNCE_MS } from '../constants' @@ -16,9 +16,6 @@ const appListQueryParsers = { category: parseAsStringLiteral(APP_LIST_CATEGORY_VALUES) .withDefault('all') .withOptions({ history: 'push' }), - tagIDs: parseAsArrayOf(parseAsString, ';') - .withDefault([]) - .withOptions({ history: 'push' }), keywords: parseAsString.withDefault('').withOptions({ limitUrlUpdates: debounce(APP_LIST_SEARCH_DEBOUNCE_MS), }), @@ -39,10 +36,6 @@ export function useAppsQueryState() { setQuery({ keywords }) }, [setQuery]) - const setTagIDs = useCallback((tagIDs: string[]) => { - setQuery({ tagIDs }) - }, [setQuery]) - const setIsCreatedByMe = useCallback((isCreatedByMe: boolean) => { setQuery({ isCreatedByMe }) }, [setQuery]) @@ -51,7 +44,6 @@ export function useAppsQueryState() { query, setCategory, setKeywords, - setTagIDs, setIsCreatedByMe, - }), [query, setCategory, setKeywords, setTagIDs, setIsCreatedByMe]) + }), [query, setCategory, setKeywords, setIsCreatedByMe]) } diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index ceb894d537..ebf2ec446c 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -10,6 +10,7 @@ import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import { CheckModal } from '@/hooks/use-pay' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { consoleQuery } from '@/service/client' import { systemFeaturesQueryOptions } from '@/service/system-features' import { AppModeEnum } from '@/types/app' @@ -32,15 +33,18 @@ function List({ controlRefreshList = 0 }: { controlRefreshList?: number }) { const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext() const { onPlanInfoChanged } = useProviderContext() + const searchParams = useSearchParams() + const pathname = usePathname() + const { replace } = useRouter() // eslint-disable-next-line react/use-state -- custom URL query hook, not React.useState const { - query: { category, tagIDs, keywords, isCreatedByMe, emptyAppList }, + query: { category, keywords, isCreatedByMe, emptyAppList }, setCategory, setKeywords, - setTagIDs, setIsCreatedByMe, } = useAppsQueryState() + const [tagIDs, setTagIDs] = useState([]) const debouncedKeywords = useDebounce(keywords, { wait: APP_LIST_SEARCH_DEBOUNCE_MS }) const newAppCardRef = useRef(null) const containerRef = useRef(null) @@ -61,6 +65,16 @@ function List({ controlRefreshList = 0 }: { controlRefreshList?: number }) { enabled: isCurrentWorkspaceEditor, }) + useEffect(() => { + if (!searchParams.has('tagIDs')) + return + + const params = new URLSearchParams(searchParams.toString()) + params.delete('tagIDs') + const query = params.toString() + replace(query ? `${pathname}?${query}` : pathname, { scroll: false }) + }, [pathname, replace, searchParams]) + const appListQuery = useMemo(() => ({ page: 1, limit: 30, diff --git a/web/app/components/education-verify-action-recorder.tsx b/web/app/components/education-verify-action-recorder.tsx new file mode 100644 index 0000000000..017bfa945e --- /dev/null +++ b/web/app/components/education-verify-action-recorder.tsx @@ -0,0 +1,19 @@ +'use client' + +import { useEffect } from 'react' +import { + EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, + EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, +} from '@/app/education-apply/constants' +import { useSearchParams } from '@/next/navigation' + +export function EducationVerifyActionRecorder() { + const searchParams = useSearchParams() + + useEffect(() => { + if (searchParams.get('action') === EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION) + localStorage.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes') + }, [searchParams]) + + return null +} diff --git a/web/app/(commonLayout)/loading.tsx b/web/app/components/full-screen-loading.tsx similarity index 58% rename from web/app/(commonLayout)/loading.tsx rename to web/app/components/full-screen-loading.tsx index f9721ce5e0..c38fb1ed32 100644 --- a/web/app/(commonLayout)/loading.tsx +++ b/web/app/components/full-screen-loading.tsx @@ -1,6 +1,6 @@ -import Loading from '@/app/components/base/loading' +import Loading from './base/loading' -export default function CommonLayoutLoading() { +export function FullScreenLoading() { return (
diff --git a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx index 107ebd7028..a894832936 100644 --- a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx +++ b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx @@ -277,7 +277,6 @@ describe('AccountDropdown', () => { // Assert await waitFor(() => { expect(mockLogout).toHaveBeenCalled() - expect(localStorage.removeItem).toHaveBeenCalledWith('setup_status') expect(mockPush).toHaveBeenCalledWith('/signin') }) }) diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 1925651a66..ef17ec5ff4 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -42,7 +42,6 @@ export default function AppSelector({ const handleLogout = async () => { await logout() resetUser() - localStorage.removeItem('setup_status') // Tokens are now stored in cookies and cleared by backend // To avoid use other account's education notice info diff --git a/web/app/components/header/header-wrapper.tsx b/web/app/components/header/header-wrapper.tsx index 6a35c34dd8..ea45626407 100644 --- a/web/app/components/header/header-wrapper.tsx +++ b/web/app/components/header/header-wrapper.tsx @@ -1,7 +1,11 @@ 'use client' +import type { EventEmitterValue } from '@/context/event-emitter' import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' +import { useState } from 'react' +import { useEventEmitterContextContext } from '@/context/event-emitter' import { usePathname } from '@/next/navigation' +import { useLocalStorageBoolean } from '@/utils/local-storage' import s from './index.module.css' type HeaderWrapperProps = { @@ -13,9 +17,20 @@ const HeaderWrapper = ({ }: HeaderWrapperProps) => { const pathname = usePathname() const isBordered = ['/apps', '/datasets/create', '/tools'].includes(pathname) + const inWorkflowCanvas = pathname.endsWith('/workflow') + const isPipelineCanvas = pathname.endsWith('/pipeline') + const storedHideHeader = useLocalStorageBoolean('workflow-canvas-maximize') + const [eventHideHeader, setEventHideHeader] = useState(null) + const hideHeader = eventHideHeader ?? storedHideHeader + const { eventEmitter } = useEventEmitterContextContext() + + eventEmitter?.useSubscription((v: EventEmitterValue) => { + if (typeof v === 'object' && v?.type === 'workflow-canvas-maximize' && typeof v.payload === 'boolean') + setEventHideHeader(v.payload) + }) return ( -
+
{children}
) diff --git a/web/app/components/header/maintenance-notice.tsx b/web/app/components/header/maintenance-notice.tsx index c60830c3cc..6d0adfbeba 100644 --- a/web/app/components/header/maintenance-notice.tsx +++ b/web/app/components/header/maintenance-notice.tsx @@ -3,19 +3,22 @@ import { useTranslation } from 'react-i18next' import { X } from '@/app/components/base/icons/src/vender/line/general' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { NOTICE_I18N } from '@/i18n-config/language' +import { setLocalStorageItem, useLocalStorageItem } from '@/utils/local-storage' const MaintenanceNotice = () => { const { t } = useTranslation() const locale = useLanguage() - const [showNotice, setShowNotice] = useState(() => localStorage.getItem('hide-maintenance-notice') !== '1') + const hiddenNotice = useLocalStorageItem('hide-maintenance-notice') === '1' + const [closedInSession, setClosedInSession] = useState(false) + const showNotice = !hiddenNotice && !closedInSession const handleJumpNotice = () => { window.open(NOTICE_I18N.href, '_blank') } const handleCloseNotice = () => { - localStorage.setItem('hide-maintenance-notice', '1') - setShowNotice(false) + setLocalStorageItem('hide-maintenance-notice', '1') + setClosedInSession(true) } const titleByLocale: { [key: string]: string } = NOTICE_I18N.title diff --git a/web/app/components/oauth-registration-analytics.tsx b/web/app/components/oauth-registration-analytics.tsx new file mode 100644 index 0000000000..73e90a1870 --- /dev/null +++ b/web/app/components/oauth-registration-analytics.tsx @@ -0,0 +1,66 @@ +'use client' + +import Cookies from 'js-cookie' +import { useEffect, useRef } from 'react' +import { useSearchParams } from '@/next/navigation' +import { sendGAEvent } from '@/utils/gtag' +import { trackEvent } from './base/amplitude' + +const OAUTH_NEW_USER_PARAM = 'oauth_new_user' + +const isRecord = (value: unknown): value is Record => + Boolean(value) && typeof value === 'object' && !Array.isArray(value) + +const removeOAuthNewUserParam = () => { + const url = new URL(window.location.href) + url.searchParams.delete(OAUTH_NEW_USER_PARAM) + window.history.replaceState(window.history.state, '', `${url.pathname}${url.search}${url.hash}`) +} + +export function OAuthRegistrationAnalytics() { + const searchParams = useSearchParams() + const oauthNewUserParam = searchParams.get(OAUTH_NEW_USER_PARAM) + const handledParamRef = useRef(null) + + useEffect(() => { + if (oauthNewUserParam === null || handledParamRef.current === oauthNewUserParam) + return + + handledParamRef.current = oauthNewUserParam + const oauthNewUser = oauthNewUserParam === 'true' + if (!oauthNewUser) { + removeOAuthNewUserParam() + return + } + + let utmInfo: Record | null = null + const utmInfoStr = Cookies.get('utm_info') + if (utmInfoStr) { + try { + const parsed: unknown = JSON.parse(utmInfoStr) + if (isRecord(parsed)) + utmInfo = parsed + } + catch (e) { + console.error('Failed to parse utm_info cookie:', e) + } + } + + const eventName = utmInfo ? 'user_registration_success_with_utm' : 'user_registration_success' + + trackEvent(eventName, { + method: 'oauth', + ...utmInfo, + }) + + sendGAEvent(eventName, { + method: 'oauth', + ...utmInfo, + }) + + Cookies.remove('utm_info') + removeOAuthNewUserParam() + }, [oauthNewUserParam]) + + return null +} diff --git a/web/app/components/plugins/plugin-page/plugin-tasks/components/__tests__/plugin-item.spec.tsx b/web/app/components/plugins/plugin-page/plugin-tasks/components/__tests__/plugin-item.spec.tsx index ea8257af2a..e4dac6d4bf 100644 --- a/web/app/components/plugins/plugin-page/plugin-tasks/components/__tests__/plugin-item.spec.tsx +++ b/web/app/components/plugins/plugin-page/plugin-tasks/components/__tests__/plugin-item.spec.tsx @@ -3,6 +3,12 @@ import { fireEvent, render, screen } from '@testing-library/react' import { PluginSource, TaskStatus } from '@/app/components/plugins/types' import PluginItem from '../plugin-item' +vi.mock('@/app/components/base/icons/src/vender/solid/mediaAndDevices', () => ({ + MagicBox: ({ className }: { className?: string }) => ( + + ), +})) + vi.mock('@/app/components/plugins/card/base/card-icon', () => ({ default: ({ src, size }: { src: string, size: string }) => (
@@ -108,6 +114,23 @@ describe('PluginItem', () => { expect(cardIcon).toHaveAttribute('data-src', 'https://example.com/icons/my-icon.svg') expect(cardIcon).toHaveAttribute('data-size', 'small') }) + + it('should show default tool icon when plugin icon is empty', () => { + const { container } = render( + } + statusText="status" + />, + ) + + expect(mockGetIconUrl).not.toHaveBeenCalled() + expect(screen.queryByTestId('card-icon')).not.toBeInTheDocument() + expect(container.querySelector('[data-testid="magic-box-icon"]')).toHaveClass('size-8', 'text-text-tertiary') + expect(screen.getByTestId('status-icon').parentElement).toHaveClass('absolute', '-bottom-0.5', '-right-0.5', 'z-10') + }) }) describe('Props', () => { diff --git a/web/app/components/plugins/plugin-page/plugin-tasks/components/plugin-item.tsx b/web/app/components/plugins/plugin-page/plugin-tasks/components/plugin-item.tsx index 530544719a..f3a7eb6a84 100644 --- a/web/app/components/plugins/plugin-page/plugin-tasks/components/plugin-item.tsx +++ b/web/app/components/plugins/plugin-page/plugin-tasks/components/plugin-item.tsx @@ -1,6 +1,7 @@ import type { FC, ReactNode } from 'react' import type { PluginStatus } from '@/app/components/plugins/types' import type { Locale } from '@/i18n-config' +import { MagicBox } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' import CardIcon from '@/app/components/plugins/card/base/card-icon' type PluginItemProps = { @@ -24,13 +25,20 @@ const PluginItem: FC = ({ action, onClear, }) => { + const hasPluginIcon = !!plugin.icon + return (
- + {hasPluginIcon + ? ( + + ) + // eslint-disable-next-line hyoban/prefer-tailwind-icons -- Reuse the same MagicBox component as the marketplace install button. + : }
{statusIcon}
diff --git a/web/app/components/workflow/store/workflow/layout-slice.ts b/web/app/components/workflow/store/workflow/layout-slice.ts index 833d8e7e92..fe5a3d1483 100644 --- a/web/app/components/workflow/store/workflow/layout-slice.ts +++ b/web/app/components/workflow/store/workflow/layout-slice.ts @@ -1,4 +1,5 @@ import type { StateCreator } from 'zustand' +import { getLocalStorageBoolean, getLocalStorageNumber } from '@/utils/local-storage' export type LayoutSliceShape = { workflowCanvasWidth?: number @@ -20,6 +21,8 @@ export type LayoutSliceShape = { setBottomPanelHeight: (height: number) => void variableInspectPanelHeight: number // min-height = 120px; default-height = 320px; setVariableInspectPanelHeight: (height: number) => void + maximizeCanvas: boolean + setMaximizeCanvas: (maximize: boolean) => void } export const createLayoutSlice: StateCreator = set => ({ @@ -32,10 +35,10 @@ export const createLayoutSlice: StateCreator = set => ({ rightPanelWidth: undefined, setRightPanelWidth: width => set(state => state.rightPanelWidth === width ? state : ({ rightPanelWidth: width })), - nodePanelWidth: localStorage.getItem('workflow-node-panel-width') ? Number.parseFloat(localStorage.getItem('workflow-node-panel-width')!) : 400, + nodePanelWidth: getLocalStorageNumber('workflow-node-panel-width', 400), setNodePanelWidth: width => set(state => state.nodePanelWidth === width ? state : ({ nodePanelWidth: width })), - previewPanelWidth: localStorage.getItem('debug-and-preview-panel-width') ? Number.parseFloat(localStorage.getItem('debug-and-preview-panel-width')!) : 400, + previewPanelWidth: getLocalStorageNumber('debug-and-preview-panel-width', 400), setPreviewPanelWidth: width => set(state => state.previewPanelWidth === width ? state : ({ previewPanelWidth: width })), otherPanelWidth: 400, @@ -47,7 +50,10 @@ export const createLayoutSlice: StateCreator = set => ({ bottomPanelHeight: 324, setBottomPanelHeight: height => set(state => state.bottomPanelHeight === height ? state : ({ bottomPanelHeight: height })), - variableInspectPanelHeight: localStorage.getItem('workflow-variable-inpsect-panel-height') ? Number.parseFloat(localStorage.getItem('workflow-variable-inpsect-panel-height')!) : 320, + variableInspectPanelHeight: getLocalStorageNumber('workflow-variable-inpsect-panel-height', 320), setVariableInspectPanelHeight: height => set(state => state.variableInspectPanelHeight === height ? state : ({ variableInspectPanelHeight: height })), + maximizeCanvas: getLocalStorageBoolean('workflow-canvas-maximize'), + setMaximizeCanvas: maximize => set(state => + state.maximizeCanvas === maximize ? state : ({ maximizeCanvas: maximize })), }) diff --git a/web/app/components/workflow/store/workflow/panel-slice.ts b/web/app/components/workflow/store/workflow/panel-slice.ts index 09f08b68ba..9bace47b44 100644 --- a/web/app/components/workflow/store/workflow/panel-slice.ts +++ b/web/app/components/workflow/store/workflow/panel-slice.ts @@ -1,4 +1,5 @@ import type { StateCreator } from 'zustand' +import { getLocalStorageNumber } from '@/utils/local-storage' export type WorkflowContextMenuTarget = | { type: 'panel' } @@ -33,7 +34,7 @@ export type PanelSliceShape = { } export const createPanelSlice: StateCreator = set => ({ - panelWidth: localStorage.getItem('workflow-node-panel-width') ? Number.parseFloat(localStorage.getItem('workflow-node-panel-width')!) : 420, + panelWidth: getLocalStorageNumber('workflow-node-panel-width', 420), showFeaturesPanel: false, setShowFeaturesPanel: showFeaturesPanel => set(() => ({ showFeaturesPanel })), showWorkflowVersionHistoryPanel: false, diff --git a/web/app/components/workflow/store/workflow/workflow-slice.ts b/web/app/components/workflow/store/workflow/workflow-slice.ts index 58e3debc63..bbcd178da1 100644 --- a/web/app/components/workflow/store/workflow/workflow-slice.ts +++ b/web/app/components/workflow/store/workflow/workflow-slice.ts @@ -6,6 +6,7 @@ import type { WorkflowRunningData, } from '@/app/components/workflow/types' import type { FileUploadConfigResponse } from '@/models/common' +import { getLocalStorageItem, setLocalStorageItem } from '@/utils/local-storage' type PreviewRunningData = WorkflowRunningData & { resultTabActive?: boolean @@ -21,6 +22,14 @@ type MousePosition = { elementY: number } +const getStoredControlMode = () => { + const storedControlMode = getLocalStorageItem('workflow-operation-mode') + if (storedControlMode === 'pointer' || storedControlMode === 'hand' || storedControlMode === 'comment') + return storedControlMode + + return 'pointer' +} + export type WorkflowSliceShape = { workflowRunningData?: PreviewRunningData setWorkflowRunningData: (workflowData: PreviewRunningData) => void @@ -92,16 +101,10 @@ export const createWorkflowSlice: StateCreator = set => ({ setSelection: selection => set(() => ({ selection })), bundleNodeSize: null, setBundleNodeSize: bundleNodeSize => set(() => ({ bundleNodeSize })), - controlMode: (() => { - const storedControlMode = localStorage.getItem('workflow-operation-mode') - if (storedControlMode === 'pointer' || storedControlMode === 'hand' || storedControlMode === 'comment') - return storedControlMode - - return 'pointer' - })(), + controlMode: getStoredControlMode(), setControlMode: (controlMode) => { set(() => ({ controlMode })) - localStorage.setItem('workflow-operation-mode', controlMode) + setLocalStorageItem('workflow-operation-mode', controlMode) }, pendingComment: null, setPendingComment: pendingComment => set(() => ({ pendingComment })), diff --git a/web/app/device/__tests__/page-terminal.spec.tsx b/web/app/device/__tests__/page-terminal.spec.tsx index 57d749897c..cb69653ead 100644 --- a/web/app/device/__tests__/page-terminal.spec.tsx +++ b/web/app/device/__tests__/page-terminal.spec.tsx @@ -6,9 +6,10 @@ import DevicePage from '../page' const mockPush = vi.fn() const mockReplace = vi.fn() const mockDeviceLookup = vi.fn() +let mockSearchParams: Record = {} vi.mock('@/next/navigation', () => ({ - useSearchParams: () => ({ get: () => null }), + useSearchParams: () => ({ get: (key: string) => mockSearchParams[key] ?? null }), useRouter: () => ({ push: mockPush, replace: mockReplace }), usePathname: () => '/device', })) @@ -38,8 +39,11 @@ vi.mock('@/service/system-features', () => ({ systemFeaturesQueryOptions: () => ({ queryKey: ['sys'], queryFn: async () => ({}) }), })) -vi.mock('@/service/use-common', () => ({ +vi.mock('@/features/account-profile/client', () => ({ userProfileQueryOptions: () => ({ queryKey: ['profile'], queryFn: async () => null }), +})) + +vi.mock('@/service/use-common', () => ({ commonQueryKeys: { currentWorkspace: ['currentWorkspace'] }, })) @@ -53,6 +57,12 @@ let MockDeviceFlowError: MockDeviceFlowErrorCtor beforeEach(async () => { vi.clearAllMocks() + mockSearchParams = {} + // router.replace(pathname) in the real app drops the query string; mirror + // that so useSearchParams reflects the cleared URL on the next render. + mockReplace.mockImplementation(() => { + mockSearchParams = {} + }) mockUseQuery.mockReturnValue({ data: undefined, isError: false } as ReturnType) const mod = await import('@/service/device-flow') as { DeviceFlowError: MockDeviceFlowErrorCtor } MockDeviceFlowError = mod.DeviceFlowError @@ -110,3 +120,41 @@ describe('error_lookup_failed terminal state', () => { expect(screen.queryByText('Could not verify the code')).not.toBeInTheDocument() }) }) + +describe('sso_error inline banner on the code-entry page', () => { + const SSO_BANNER_COPY = /identity is linked to a Dify account/i + + it('shows the error banner with friendly copy when sso_error is present', async () => { + mockSearchParams = { sso_error: 'email_belongs_to_dify_account' } + render() + expect(await screen.findByText(SSO_BANNER_COPY)).toBeInTheDocument() + }) + + it('keeps the code-entry screen visible (error on main page, not a separate view)', async () => { + mockSearchParams = { sso_error: 'email_belongs_to_dify_account' } + render() + await screen.findByText(SSO_BANNER_COPY) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: /Continue/i })).toBeInTheDocument() + }) + + it('does not surface the raw backend error code', async () => { + mockSearchParams = { sso_error: 'email_belongs_to_dify_account' } + render() + await screen.findByText(SSO_BANNER_COPY) + expect(screen.queryByText('email_belongs_to_dify_account')).not.toBeInTheDocument() + }) + + it('does not scrub the param on mount (regression: error was wiped by router.replace)', async () => { + mockSearchParams = { sso_error: 'email_belongs_to_dify_account' } + render() + await screen.findByText(SSO_BANNER_COPY) + expect(mockReplace).not.toHaveBeenCalled() + }) + + it('shows no banner when sso_error is absent', () => { + render() + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.queryByText(SSO_BANNER_COPY)).not.toBeInTheDocument() + }) +}) diff --git a/web/app/device/page.tsx b/web/app/device/page.tsx index 83def36a75..aa09936b1f 100644 --- a/web/app/device/page.tsx +++ b/web/app/device/page.tsx @@ -5,16 +5,17 @@ import { Button } from '@langgenius/dify-ui/button' import { useQuery } from '@tanstack/react-query' import { useEffect, useState } from 'react' import Divider from '@/app/components/base/divider' +import { userProfileQueryOptions } from '@/features/account-profile/client' import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { post } from '@/service/base' import { deviceLookup } from '@/service/device-flow' import { systemFeaturesQueryOptions } from '@/service/system-features' -import { commonQueryKeys, userProfileQueryOptions } from '@/service/use-common' +import { commonQueryKeys } from '@/service/use-common' import AuthorizeAccount from './components/authorize-account' import AuthorizeSSO from './components/authorize-sso' import Chooser from './components/chooser' import CodeInput from './components/code-input' -import { classifyLookupError } from './utils/error-copy' +import { classifyLookupError, ssoErrorCopy } from './utils/error-copy' import { isValidUserCode } from './utils/user-code' type View @@ -33,6 +34,7 @@ export default function DevicePage() { const pathname = usePathname() const urlUserCode = (searchParams.get('user_code') || '').trim().toUpperCase() const ssoVerified = searchParams.get('sso_verified') === '1' + const ssoError = searchParams.get('sso_error') || '' const [typed, setTyped] = useState('') const [view, setView] = useState({ kind: 'code_entry' }) @@ -125,6 +127,12 @@ export default function DevicePage() { <> {view.kind === 'code_entry' && (
+ {ssoError && ( +
+ +

{ssoErrorCopy(ssoError)}

+
+ )}

Authorize Dify CLI

diff --git a/web/app/device/utils/error-copy.ts b/web/app/device/utils/error-copy.ts index 9360fb167e..d0184dad7b 100644 --- a/web/app/device/utils/error-copy.ts +++ b/web/app/device/utils/error-copy.ts @@ -30,6 +30,18 @@ export function approveErrorCopy(err: unknown): string { return DEFAULT_MESSAGE } +// SSO-branch failures arrive as a `sso_error` query param set by the backend +// (oauth_device_sso sso-complete) when it redirects back to /device. +const SSO_ERROR_COPY: Record = { + email_belongs_to_dify_account: 'This identity is linked to a Dify account. Use “Sign in with Dify account” instead.', +} + +const DEFAULT_SSO_ERROR_MESSAGE = 'Single sign-on could not be completed. Try again.' + +export function ssoErrorCopy(code: string): string { + return SSO_ERROR_COPY[code] ?? DEFAULT_SSO_ERROR_MESSAGE +} + export type LookupOutcome = 'expired' | 'rate_limited' | 'failed' export function classifyLookupError(err: unknown): LookupOutcome { diff --git a/web/app/education-apply/user-info.tsx b/web/app/education-apply/user-info.tsx index be9b319038..b25d693a65 100644 --- a/web/app/education-apply/user-info.tsx +++ b/web/app/education-apply/user-info.tsx @@ -15,7 +15,6 @@ const UserInfo = () => { const handleLogout = async () => { await logout() - localStorage.removeItem('setup_status') // Tokens are now stored in cookies and cleared by backend router.push('/signin') diff --git a/web/app/error.tsx b/web/app/error.tsx new file mode 100644 index 0000000000..33c2b9189c --- /dev/null +++ b/web/app/error.tsx @@ -0,0 +1,33 @@ +'use client' + +import { Button } from '@langgenius/dify-ui/button' +import { useTranslation } from 'react-i18next' +import { FullScreenLoading } from '@/app/components/full-screen-loading' +import { isLegacyBase401 } from '@/features/account-profile/client' + +type Props = { + error: Error & { digest?: string } + reset?: () => void + unstable_retry?: () => void +} + +export default function AppError({ error, reset, unstable_retry }: Props) { + const { t } = useTranslation('common') + const retry = reset ?? unstable_retry + + if (isLegacyBase401(error)) + return + + return ( +

+
+ {t('errorBoundary.message')} +
+ {retry && ( + + )} +
+ ) +} diff --git a/web/app/install/installForm.spec.tsx b/web/app/install/installForm.spec.tsx index a9b8cc02be..65ad0b0df1 100644 --- a/web/app/install/installForm.spec.tsx +++ b/web/app/install/installForm.spec.tsx @@ -154,7 +154,6 @@ describe('InstallForm', () => { render() await waitFor(() => { - expect(localStorage.setItem).toHaveBeenCalledWith('setup_status', 'finished') expect(mockPush).toHaveBeenCalledWith('/signin') }) }) diff --git a/web/app/install/installForm.tsx b/web/app/install/installForm.tsx index d9b6e0c7ad..20d2b26769 100644 --- a/web/app/install/installForm.tsx +++ b/web/app/install/installForm.tsx @@ -87,7 +87,6 @@ const InstallForm = () => { useEffect(() => { fetchSetupStatus().then((res: SetupStatusResponse) => { if (res.step === 'finished') { - localStorage.setItem('setup_status', 'finished') router.push('/signin') } else { diff --git a/web/app/loading.tsx b/web/app/loading.tsx deleted file mode 100644 index b108baaa97..0000000000 --- a/web/app/loading.tsx +++ /dev/null @@ -1,9 +0,0 @@ -import Loading from '@/app/components/base/loading' - -export default function RootLoading() { - return ( -
- -
- ) -} diff --git a/web/app/signin/normal-form.tsx b/web/app/signin/normal-form.tsx index 76ac6b2ab4..23576837c0 100644 --- a/web/app/signin/normal-form.tsx +++ b/web/app/signin/normal-form.tsx @@ -6,11 +6,11 @@ import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { IS_CE_EDITION } from '@/config' +import { isLegacyBase401, userProfileQueryOptions } from '@/features/account-profile/client' import Link from '@/next/link' import { useRouter, useSearchParams } from '@/next/navigation' import { invitationCheck } from '@/service/common' import { systemFeaturesQueryOptions } from '@/service/system-features' -import { isLegacyBase401, userProfileQueryOptions } from '@/service/use-common' import { LicenseStatus } from '@/types/feature' import Loading from '../components/base/loading' import MailAndCodeAuth from './components/mail-and-code-auth' diff --git a/web/app/signin/utils/__tests__/post-login-redirect.spec.ts b/web/app/signin/utils/__tests__/post-login-redirect.spec.ts new file mode 100644 index 0000000000..00e270db2b --- /dev/null +++ b/web/app/signin/utils/__tests__/post-login-redirect.spec.ts @@ -0,0 +1,31 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { resolvePostLoginRedirect, setPostLoginRedirect } from '../post-login-redirect' + +describe('post-login redirect utilities', () => { + beforeEach(() => { + vi.useRealTimers() + window.localStorage.clear() + window.sessionStorage.clear() + }) + + it('should use the redirect_url query param first', () => { + const searchParams = new URLSearchParams({ + redirect_url: encodeURIComponent('/account/oauth/authorize?client_id=app&redirect_uri=https%3A%2F%2Fexample.com%2Fcallback'), + }) + + expect(resolvePostLoginRedirect(searchParams as unknown as Parameters[0])).toBe('/account/oauth/authorize?client_id=app&redirect_uri=https%3A%2F%2Fexample.com%2Fcallback') + }) + + it('should recover a valid device redirect from sessionStorage once', () => { + setPostLoginRedirect('/device?user_code=ABCD&sso_verified=true') + + expect(resolvePostLoginRedirect()).toBe('/device?user_code=ABCD&sso_verified=true') + expect(resolvePostLoginRedirect()).toBeNull() + }) + + it('should ignore invalid stored redirects', () => { + setPostLoginRedirect('https://example.com/device?user_code=ABCD') + + expect(resolvePostLoginRedirect()).toBeNull() + }) +}) diff --git a/web/app/signin/utils/post-login-redirect.ts b/web/app/signin/utils/post-login-redirect.ts index 2c8d82d9bb..363cd6bdf6 100644 --- a/web/app/signin/utils/post-login-redirect.ts +++ b/web/app/signin/utils/post-login-redirect.ts @@ -1,6 +1,5 @@ import type { ReadonlyURLSearchParams } from '@/next/navigation' -const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending_redirect' const REDIRECT_URL_KEY = 'redirect_url' const DEVICE_REDIRECT_KEY = 'dify_post_login_redirect' const DEVICE_TTL_MS = 15 * 60 * 1000 @@ -10,13 +9,6 @@ const ALLOWED: Record> = { '/account/oauth/authorize': new Set(['client_id', 'scope', 'state', 'redirect_uri']), } -type OAuthPendingRedirect = { - value?: string - expiry?: number -} - -const getCurrentUnixTimestamp = () => Math.floor(Date.now() / 1000) - function validate(target: string): string | null { if (typeof window === 'undefined') return null @@ -87,51 +79,14 @@ function getDeviceRedirect(): string | null { } } -function removeOAuthPendingRedirect() { - try { - localStorage.removeItem(OAUTH_AUTHORIZE_PENDING_KEY) - } - catch {} -} - -function getOAuthPendingRedirect(): string | null { - try { - const raw = localStorage.getItem(OAUTH_AUTHORIZE_PENDING_KEY) - if (!raw) - return null - removeOAuthPendingRedirect() - const item: OAuthPendingRedirect = JSON.parse(raw) - if (!item.value || typeof item.expiry !== 'number') - return null - return getCurrentUnixTimestamp() > item.expiry ? null : item.value - } - catch { - removeOAuthPendingRedirect() - return null - } -} - -export function setOAuthPendingRedirect(url: string, ttlSeconds: number = 300) { - try { - const item: OAuthPendingRedirect = { - value: url, - expiry: getCurrentUnixTimestamp() + ttlSeconds, - } - localStorage.setItem(OAUTH_AUTHORIZE_PENDING_KEY, JSON.stringify(item)) - } - catch {} -} - export const resolvePostLoginRedirect = (searchParams?: ReadonlyURLSearchParams) => { if (searchParams) { const redirectUrl = searchParams.get(REDIRECT_URL_KEY) if (redirectUrl) { try { - removeOAuthPendingRedirect() return decodeURIComponent(redirectUrl) } catch { - removeOAuthPendingRedirect() return redirectUrl } } @@ -139,5 +94,5 @@ export const resolvePostLoginRedirect = (searchParams?: ReadonlyURLSearchParams) const device = getDeviceRedirect() if (device) return device - return getOAuthPendingRedirect() + return null } diff --git a/web/config/server.ts b/web/config/server.ts new file mode 100644 index 0000000000..388363642d --- /dev/null +++ b/web/config/server.ts @@ -0,0 +1,10 @@ +import { env } from '@/env' + +import 'server-only' + +const withoutTrailingSlash = (value: string) => value.endsWith('/') ? value.slice(0, -1) : value + +// Server-side requests need the origin; browser requests should keep using NEXT_PUBLIC_API_PREFIX. +export const SERVER_CONSOLE_API_PREFIX = env.CONSOLE_API_URL + ? `${withoutTrailingSlash(env.CONSOLE_API_URL)}/console/api` + : undefined diff --git a/web/context/app-context-provider.tsx b/web/context/app-context-provider.tsx index ab09f4618c..2963c90dbc 100644 --- a/web/context/app-context-provider.tsx +++ b/web/context/app-context-provider.tsx @@ -16,11 +16,11 @@ import { useSelector, } from '@/context/app-context' import { env } from '@/env' +import { userProfileQueryOptions } from '@/features/account-profile/client' import { systemFeaturesQueryOptions } from '@/service/system-features' import { useCurrentWorkspace, useLangGeniusVersion, - userProfileQueryOptions, } from '@/service/use-common' type AppContextProviderProps = { @@ -29,11 +29,6 @@ type AppContextProviderProps = { export const AppContextProvider: FC = ({ children }) => { const queryClient = useQueryClient() - // Boot point for the (commonLayout) tree: - // - useSuspenseQuery for systemFeatures triggers app/loading.tsx until cache is warm. - // - useSuspenseQuery for userProfile triggers (commonLayout)/loading.tsx until cache is warm. - // After this provider mounts, downstream components reading the same queryKeys hit cache - // and never suspend again, so their useSuspenseQuery calls return data synchronously. const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) const { data: userProfileResp } = useSuspenseQuery(userProfileQueryOptions()) const { data: currentWorkspaceResp, isPending: isLoadingCurrentWorkspace, isFetching: isValidatingCurrentWorkspace } = useCurrentWorkspace() @@ -65,7 +60,7 @@ export const AppContextProvider: FC = ({ children }) => const isCurrentWorkspaceDatasetOperator = useMemo(() => currentWorkspace.role === 'dataset_operator', [currentWorkspace.role]) const mutateUserProfile = useCallback(() => { - queryClient.invalidateQueries({ queryKey: ['common', 'user-profile'] }) + queryClient.invalidateQueries({ queryKey: userProfileQueryOptions().queryKey }) }, [queryClient]) const mutateCurrentWorkspace = useCallback(() => { diff --git a/web/context/modal-context.test.tsx b/web/context/modal-context.test.tsx index a4a8d252b4..c73fca886a 100644 --- a/web/context/modal-context.test.tsx +++ b/web/context/modal-context.test.tsx @@ -21,6 +21,10 @@ vi.mock('@/next/navigation', () => ({ useSearchParams: vi.fn(() => new URLSearchParams()), })) +vi.mock('@/app/components/billing/pricing', () => ({ + default: () =>
billing.plansCommon.mostPopular
, +})) + const mockUseProviderContext = vi.fn() vi.mock('@/context/provider-context', () => ({ useProviderContext: () => mockUseProviderContext(), diff --git a/web/contract/console/account.ts b/web/contract/console/account.ts index a8a468c40d..5e8e27e015 100644 --- a/web/contract/console/account.ts +++ b/web/contract/console/account.ts @@ -1,6 +1,29 @@ import { type } from '@orpc/contract' import { base } from '../base' +export type AccountProfileResponse = { + id: string + name: string + email: string + avatar: string + avatar_url: string | null + is_password_set: boolean + interface_language?: string + interface_theme?: string + timezone?: string + last_login_at?: string + last_active_at?: string + last_login_ip?: string + created_at?: string +} + +export const accountProfileContract = base + .route({ + path: '/account/profile', + method: 'GET', + }) + .output(type()) + export const accountAvatarContract = base .route({ path: '/account/avatar', diff --git a/web/contract/router.ts b/web/contract/router.ts index ef4cba5e69..1a117300a5 100644 --- a/web/contract/router.ts +++ b/web/contract/router.ts @@ -1,7 +1,7 @@ import type { InferContractRouterInputs } from '@orpc/contract' import { contract as communityContract } from '@dify/contracts/api/console/orpc.gen' import { contract as enterpriseContract } from '@dify/contracts/enterprise/orpc.gen' -import { accountAvatarContract } from './console/account' +import { accountAvatarContract, accountProfileContract } from './console/account' import { appDeleteContract, appListContract, workflowOnlineUsersContract } from './console/apps' import { bindPartnerStackContract, invoicesContract } from './console/billing' import { @@ -77,6 +77,10 @@ export const consoleRouterContract = { account: { ...communityContract.account, avatar: accountAvatarContract, + profile: { + ...communityContract.account.profile, + get: accountProfileContract, + }, }, systemFeatures: systemFeaturesContract, apps: { diff --git a/web/env.ts b/web/env.ts index 0c2868be5c..16a0ea2da1 100644 --- a/web/env.ts +++ b/web/env.ts @@ -142,6 +142,7 @@ const clientSchema = { export const env = createEnv({ server: { + CONSOLE_API_URL: z.string().optional(), /** * Maximum length of segmentation tokens for indexing */ diff --git a/web/features/account-profile/__tests__/server.spec.ts b/web/features/account-profile/__tests__/server.spec.ts new file mode 100644 index 0000000000..79ed7fa571 --- /dev/null +++ b/web/features/account-profile/__tests__/server.spec.ts @@ -0,0 +1,79 @@ +import type { AccountProfileResponse } from '@/contract/console/account' +import { QueryClient } from '@tanstack/react-query' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { userProfileQueryOptions } from '../client' +import { resolveServerConsoleApiUrl } from '../server' + +const headersMock = vi.fn() +const cookiesMock = vi.fn() + +vi.mock('@/config/server', () => ({ + SERVER_CONSOLE_API_PREFIX: undefined, +})) + +vi.mock('@/next/headers', () => ({ + headers: () => headersMock(), + cookies: () => cookiesMock(), +})) + +const createProfile = (overrides: Partial = {}): AccountProfileResponse => ({ + id: 'account-id', + name: 'Dify User', + email: 'user@example.com', + avatar: '', + avatar_url: null, + is_password_set: true, + ...overrides, +}) + +describe('serverUserProfileQueryOptions', () => { + beforeEach(() => { + vi.clearAllMocks() + headersMock.mockResolvedValue(new Headers({ cookie: 'session=abc' })) + cookiesMock.mockResolvedValue({ + get: vi.fn(() => ({ value: 'csrf-token' })), + }) + }) + + it('should reuse the client profile query key and return the same data shape', async () => { + const fetchMock = vi.fn().mockResolvedValue(new Response(JSON.stringify(createProfile()), { + status: 200, + headers: { + 'content-type': 'application/json', + 'x-version': '1.2.3', + 'x-env': 'DEVELOPMENT', + }, + })) + vi.stubGlobal('fetch', fetchMock) + const { serverUserProfileQueryOptions } = await import('../server') + const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } }) + + const data = await queryClient.fetchQuery(serverUserProfileQueryOptions()) + + expect(serverUserProfileQueryOptions().queryKey).toEqual(userProfileQueryOptions().queryKey) + expect(data).toEqual({ + profile: createProfile(), + meta: { + currentVersion: '1.2.3', + currentEnv: 'DEVELOPMENT', + }, + }) + expect(fetchMock).toHaveBeenCalledWith( + 'http://localhost:5001/console/api/account/profile', + expect.objectContaining({ + method: 'GET', + cache: 'no-store', + headers: expect.any(Headers), + }), + ) + }) + + it('should skip relative API prefixes unless a server API origin is configured', () => { + expect(resolveServerConsoleApiUrl('/account/profile', undefined, '/console/api')).toBeNull() + expect(resolveServerConsoleApiUrl('/account/profile', 'https://console.example.com/console/api', '/console/api')).toBe('https://console.example.com/console/api/account/profile') + }) + + it('should preserve absolute API prefixes', () => { + expect(resolveServerConsoleApiUrl('/account/profile', undefined, 'https://console.example.com/console/api')).toBe('https://console.example.com/console/api/account/profile') + }) +}) diff --git a/web/features/account-profile/client.ts b/web/features/account-profile/client.ts new file mode 100644 index 0000000000..636d3f8450 --- /dev/null +++ b/web/features/account-profile/client.ts @@ -0,0 +1,41 @@ +import type { AccountProfileResponse } from '@/contract/console/account' +import { queryOptions } from '@tanstack/react-query' +import { IS_DEV } from '@/config' +// eslint-disable-next-line no-restricted-imports +import { get } from '@/service/base' +import { consoleQuery } from '@/service/client' + +export type UserProfileWithMeta = { + profile: AccountProfileResponse + meta: { + currentVersion: string | null + currentEnv: string | null + } +} + +export const isLegacyBase401 = (err: unknown): boolean => + err instanceof Response && err.status === 401 + +export const userProfileQueryOptions = () => + queryOptions({ + queryKey: consoleQuery.account.profile.get.queryKey(), + queryFn: async () => { + const response = await get('/account/profile', {}, { + needAllResponseContent: true, + silent: true, + }) + const profile: AccountProfileResponse = await response.clone().json() + return { + profile, + meta: { + currentVersion: response.headers.get('x-version'), + currentEnv: IS_DEV + ? 'DEVELOPMENT' + : response.headers.get('x-env'), + }, + } + }, + staleTime: 0, + gcTime: 0, + retry: (failureCount, error) => !isLegacyBase401(error) && failureCount < 3, + }) diff --git a/web/features/account-profile/server.ts b/web/features/account-profile/server.ts new file mode 100644 index 0000000000..8885f5f15f --- /dev/null +++ b/web/features/account-profile/server.ts @@ -0,0 +1,81 @@ +import type { UserProfileWithMeta } from './client' +import type { AccountProfileResponse } from '@/contract/console/account' +import { queryOptions } from '@tanstack/react-query' +import { API_PREFIX, CSRF_COOKIE_NAME, CSRF_HEADER_NAME } from '@/config' +import { SERVER_CONSOLE_API_PREFIX } from '@/config/server' +import { cookies, headers } from '@/next/headers' +import { consoleQuery } from '@/service/client' + +const ACCOUNT_PROFILE_PATH = '/account/profile' + +const withTrailingSlash = (value: string) => value.endsWith('/') ? value : `${value}/` +const withoutLeadingSlash = (value: string) => value.startsWith('/') ? value.slice(1) : value + +const resolveAbsoluteUrlPrefix = (value: string) => { + try { + return new URL(value).toString() + } + catch { + return null + } +} + +export const resolveServerConsoleApiUrl = ( + pathname: string, + serverConsoleApiPrefix = SERVER_CONSOLE_API_PREFIX, + publicApiPrefix = API_PREFIX, +) => { + const requestPath = withoutLeadingSlash(pathname) + const apiPrefix = serverConsoleApiPrefix || resolveAbsoluteUrlPrefix(publicApiPrefix) + + if (!apiPrefix) + return null + + return new URL(requestPath, withTrailingSlash(apiPrefix)).toString() +} + +const getServerRequestHeaders = async () => { + const requestHeaders = await headers() + const cookieStore = await cookies() + const outgoingHeaders = new Headers({ + 'Content-Type': 'application/json', + }) + const cookie = requestHeaders.get('cookie') + if (cookie) + outgoingHeaders.set('cookie', cookie) + const csrfToken = cookieStore.get(CSRF_COOKIE_NAME())?.value + if (csrfToken) + outgoingHeaders.set(CSRF_HEADER_NAME, csrfToken) + return outgoingHeaders +} + +export const serverUserProfileQueryOptions = () => + queryOptions({ + queryKey: consoleQuery.account.profile.get.queryKey(), + queryFn: async () => { + const profileUrl = resolveServerConsoleApiUrl(ACCOUNT_PROFILE_PATH) + if (!profileUrl) + throw new Error('Server account profile URL is not configured') + + const response = await fetch(profileUrl, { + method: 'GET', + headers: await getServerRequestHeaders(), + cache: 'no-store', + }) + + if (!response.ok) + throw response + + const profile: AccountProfileResponse = await response.clone().json() + return { + profile, + meta: { + currentVersion: response.headers.get('x-version'), + currentEnv: response.headers.get('x-env'), + }, + } + }, + staleTime: 0, + gcTime: 0, + retry: false, + }) diff --git a/web/package.json b/web/package.json index f599e7e019..92e3f6b051 100644 --- a/web/package.json +++ b/web/package.json @@ -136,6 +136,7 @@ "remark-breaks": "catalog:", "remark-directive": "catalog:", "scheduler": "catalog:", + "server-only": "catalog:", "sharp": "catalog:", "shiki": "catalog:", "socket.io-client": "catalog:", diff --git a/web/proxy.ts b/web/proxy.ts index d735c9f568..354f830619 100644 --- a/web/proxy.ts +++ b/web/proxy.ts @@ -6,6 +6,8 @@ import { NextResponse } from 'next/server' import { env } from '@/env' const NECESSARY_DOMAIN = '*.sentry.io http://localhost:* http://127.0.0.1:* https://analytics.google.com googletagmanager.com *.googletagmanager.com https://www.google-analytics.com https://ungh.cc https://api2.amplitude.com *.amplitude.com' +const CURRENT_PATHNAME_HEADER = 'x-dify-pathname' +const CURRENT_SEARCH_HEADER = 'x-dify-search' const wrapResponseWithXFrameOptions = (response: NextResponse, pathname: string) => { // prevent clickjacking: https://owasp.org/www-community/attacks/Clickjacking @@ -16,8 +18,10 @@ const wrapResponseWithXFrameOptions = (response: NextResponse, pathname: string) return response } export function proxy(request: NextRequest) { - const { pathname } = request.nextUrl + const { pathname, search } = request.nextUrl const requestHeaders = new Headers(request.headers) + requestHeaders.set(CURRENT_PATHNAME_HEADER, pathname) + requestHeaders.set(CURRENT_SEARCH_HEADER, search) const isWhiteListEnabled = !!env.NEXT_PUBLIC_CSP_WHITELIST && process.env.NODE_ENV === 'production' if (!isWhiteListEnabled) { diff --git a/web/service/__tests__/base-request.spec.ts b/web/service/__tests__/base-request.spec.ts new file mode 100644 index 0000000000..658f73ee98 --- /dev/null +++ b/web/service/__tests__/base-request.spec.ts @@ -0,0 +1,62 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' + +const createUnauthorizedResponse = () => + new Response(JSON.stringify({ + code: 'unauthorized', + message: 'Invalid Authorization token.', + status: 401, + }), { + status: 401, + headers: { + 'Content-Type': 'application/json', + }, + }) + +async function loadServerRequest() { + vi.resetModules() + + const mockBaseFetch = vi.fn(async () => { + throw createUnauthorizedResponse() + }) + const mockRefreshAccessTokenOrReLogin = vi.fn() + + vi.doMock('@/utils/client', () => ({ + isClient: false, + isServer: true, + })) + vi.doMock('../fetch', () => ({ + base: mockBaseFetch, + ContentType: { + audio: 'audio/mpeg', + download: 'application/octet-stream', + downloadZip: 'application/zip', + json: 'application/json', + }, + getBaseOptions: vi.fn(() => ({})), + })) + vi.doMock('../refresh-token', () => ({ + refreshAccessTokenOrReLogin: mockRefreshAccessTokenOrReLogin, + })) + + const { request } = await import('../base') + + return { + request, + mockRefreshAccessTokenOrReLogin, + } +} + +describe('request 401 handling', () => { + afterEach(() => { + vi.resetModules() + vi.restoreAllMocks() + }) + + it('should not run browser auth recovery when handling 401 on the server', async () => { + const { request, mockRefreshAccessTokenOrReLogin } = await loadServerRequest() + + await expect(request('/account/profile')).rejects.toMatchObject({ status: 401 }) + + expect(mockRefreshAccessTokenOrReLogin).not.toHaveBeenCalled() + }) +}) diff --git a/web/service/base.ts b/web/service/base.ts index 7e35b3d789..9bd7be8704 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -31,6 +31,7 @@ import { toast } from '@langgenius/dify-ui/toast' import Cookies from 'js-cookie' import { API_PREFIX, CSRF_COOKIE_NAME, CSRF_HEADER_NAME, IS_CE_EDITION, PASSPORT_HEADER_NAME, PUBLIC_API_PREFIX, WEB_APP_SHARE_CODE_HEADER_NAME } from '@/config' import { asyncRunSafe } from '@/utils' +import { isClient } from '@/utils/client' import { basePath } from '@/utils/var' import { base, ContentType, getBaseOptions } from './fetch' import { refreshAccessTokenOrReLogin } from './refresh-token' @@ -132,22 +133,22 @@ export type IOtherOptions = { } function jumpTo(url: string) { - if (!url) + if (!url || !isClient) return - const targetPath = new URL(url, globalThis.location.origin).pathname - if (targetPath === globalThis.location.pathname) + const targetPath = new URL(url, window.location.origin).pathname + if (targetPath === window.location.pathname) return - globalThis.location.href = url + window.location.href = url } const OAUTH_AUTHORIZE_PATH = '/account/oauth/authorize' export const buildSigninUrlWithRedirect = (): string => { - const loginUrl = `${globalThis.location.origin}${basePath}/signin` + const loginUrl = `${isClient ? window.location.origin : ''}${basePath}/signin` // Only preserve redirect URL for OAuth authorize pages - if (globalThis.location.pathname.includes(OAUTH_AUTHORIZE_PATH)) { - const currentUrl = globalThis.location.href + if (isClient && window.location.pathname.includes(OAUTH_AUTHORIZE_PATH)) { + const currentUrl = window.location.href return `${loginUrl}?redirect_url=${encodeURIComponent(currentUrl)}` } @@ -165,17 +166,20 @@ function unicodeToChar(text: string) { const WBB_APP_LOGIN_PATH = '/webapp-signin' function requiredWebSSOLogin(message?: string, code?: number) { - const params = new URLSearchParams() - // prevent redirect loop - if (globalThis.location.pathname === WBB_APP_LOGIN_PATH) + if (!isClient) return - params.append('redirect_url', encodeURIComponent(`${globalThis.location.pathname}${globalThis.location.search}`)) + const params = new URLSearchParams() + // prevent redirect loop + if (window.location.pathname === WBB_APP_LOGIN_PATH) + return + + params.append('redirect_url', encodeURIComponent(`${window.location.pathname}${window.location.search}`)) if (message) params.append('message', message) if (code) params.append('code', String(code)) - globalThis.location.href = `${globalThis.location.origin}${basePath}${WBB_APP_LOGIN_PATH}?${params.toString()}` + window.location.href = `${window.location.origin}${basePath}${WBB_APP_LOGIN_PATH}?${params.toString()}` } function formatURL(url: string, isPublicAPI: boolean) { @@ -759,10 +763,13 @@ export const request = async(url: string, options = {}, otherOptions?: IOther return resp const errResp: Response = err as any if (errResp.status === 401) { + if (!isClient) + return Promise.reject(err) + const [parseErr, errRespData] = await asyncRunSafe(errResp.json()) - const loginUrl = `${globalThis.location.origin}${basePath}/signin` + const loginUrl = `${window.location.origin}${basePath}/signin` if (parseErr) { - globalThis.location.href = loginUrl + window.location.href = loginUrl return Promise.reject(err) } if (/\/login/.test(url)) @@ -780,7 +787,7 @@ export const request = async(url: string, options = {}, otherOptions?: IOther } if (code === 'unauthorized_and_force_logout') { // Cookies will be cleared by the backend - globalThis.location.reload() + window.location.reload() return Promise.reject(err) } const { @@ -796,11 +803,11 @@ export const request = async(url: string, options = {}, otherOptions?: IOther return Promise.reject(err) } if (code === 'not_init_validated' && IS_CE_EDITION) { - jumpTo(`${globalThis.location.origin}${basePath}/init`) + jumpTo(`${window.location.origin}${basePath}/init`) return Promise.reject(err) } if (code === 'not_setup' && IS_CE_EDITION) { - jumpTo(`${globalThis.location.origin}${basePath}/install`) + jumpTo(`${window.location.origin}${basePath}/install`) return Promise.reject(err) } @@ -811,9 +818,9 @@ export const request = async(url: string, options = {}, otherOptions?: IOther // /device is the device-flow chooser; logged-out is a valid state // there. Redirecting to /signin loses the user_code context and // the post-login flow lands on /apps instead of returning here. - if (location.pathname === `${basePath}/device`) + if (window.location.pathname === `${basePath}/device`) return Promise.reject(err) - if (location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) { + if (window.location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) { jumpTo(buildSigninUrlWithRedirect()) return Promise.reject(err) } diff --git a/web/service/refresh-token.ts b/web/service/refresh-token.ts index b00a46eb6e..3c69927f27 100644 --- a/web/service/refresh-token.ts +++ b/web/service/refresh-token.ts @@ -1,5 +1,6 @@ import { API_PREFIX } from '@/config' import { fetchWithRetry } from '@/utils' +import { isClient } from '@/utils/client' const LOCAL_STORAGE_KEY = 'is_other_tab_refreshing' @@ -81,6 +82,9 @@ function releaseRefreshLock() { } export async function refreshAccessTokenOrReLogin(timeout: number) { + if (!isClient) + return Promise.reject(new Error('refresh token is client-only')) + return Promise.race([new Promise((resolve, reject) => setTimeout(() => { releaseRefreshLock() reject(new Error('request timeout')) diff --git a/web/service/use-common.ts b/web/service/use-common.ts index 5cdaa929de..315065b720 100644 --- a/web/service/use-common.ts +++ b/web/service/use-common.ts @@ -16,28 +16,15 @@ import type { PluginProvider, StructuredOutputRulesRequestBody, StructuredOutputRulesResponse, - UserProfileResponse, } from '@/models/common' import type { RETRIEVE_METHOD } from '@/types/app' -import { queryOptions, useMutation, useQuery, useQueryClient } from '@tanstack/react-query' -import { IS_DEV } from '@/config' +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' import { get, post } from './base' -/** - * True iff `err` is a 401 Response thrown by `service/base.ts`. - * - * Narrow on purpose: oRPC throws `ORPCError`, not `Response`, so this predicate - * returns `false` for oRPC 401s. Naming makes that scope visible. If you need - * 401 detection for an oRPC path, add a separate `isOrpc401` helper. - */ -export const isLegacyBase401 = (err: unknown): boolean => - err instanceof Response && err.status === 401 - const NAME_SPACE = 'common' export const commonQueryKeys = { fileUploadConfig: [NAME_SPACE, 'file-upload-config'] as const, - userProfile: [NAME_SPACE, 'user-profile'] as const, currentWorkspace: [NAME_SPACE, 'current-workspace'] as const, members: [NAME_SPACE, 'members'] as const, filePreview: (fileID: string) => [NAME_SPACE, 'file-preview', fileID] as const, @@ -71,49 +58,6 @@ export const useFileUploadConfig = () => { }) } -type UserProfileWithMeta = { - profile: UserProfileResponse - meta: { - currentVersion: string | null - currentEnv: string | null - } -} - -/** - * Session probe for `/account/profile`. Helper (not hook) because oRPC can't - * express the `x-version` / `x-env` response headers we post-process. - * - * Bindings: - * commonLayout -> `useSuspenseQuery(userProfileQueryOptions())` - * signin/oauth -> `useQuery({ ...userProfileQueryOptions(), throwOnError: err => !isLegacyBase401(err) })` - * - * `silent: true` + `retry: !isLegacyBase401` makes 401 a synchronous *state* (no toast, - * no ~7s retry storm). Transient errors still get the default 3 retries. - */ -export const userProfileQueryOptions = () => - queryOptions({ - queryKey: commonQueryKeys.userProfile, - queryFn: async () => { - const response = await get('/account/profile', {}, { - needAllResponseContent: true, - silent: true, - }) as Response - const profile = await response.clone().json() as UserProfileResponse - return { - profile, - meta: { - currentVersion: response.headers.get('x-version'), - currentEnv: IS_DEV - ? 'DEVELOPMENT' - : response.headers.get('x-env'), - }, - } - }, - staleTime: 0, - gcTime: 0, - retry: (failureCount, error) => !isLegacyBase401(error) && failureCount < 3, - }) - export const useLangGeniusVersion = (currentVersion?: string | null, enabled?: boolean) => { return useQuery({ queryKey: commonQueryKeys.langGeniusVersion(currentVersion || undefined), diff --git a/web/utils/local-storage.ts b/web/utils/local-storage.ts new file mode 100644 index 0000000000..961a9bdd0a --- /dev/null +++ b/web/utils/local-storage.ts @@ -0,0 +1,106 @@ +import { useSyncExternalStore } from 'react' +import { isClient } from './client' + +const LOCAL_STORAGE_CHANGE_EVENT = 'dify-local-storage-change' + +type LocalStorageChangeDetail = { + key: string +} + +export const getLocalStorageItem = (key: string, fallback: string | null = null) => { + if (!isClient) + return fallback + + try { + return window.localStorage.getItem(key) ?? fallback + } + catch { + return fallback + } +} + +export const setLocalStorageItem = (key: string, value: string) => { + if (!isClient) + return + + try { + window.localStorage.setItem(key, value) + window.dispatchEvent(new CustomEvent(LOCAL_STORAGE_CHANGE_EVENT, { + detail: { key }, + })) + } + catch { + + } +} + +/* @public */ +export const removeLocalStorageItem = (key: string) => { + if (!isClient) + return + + try { + window.localStorage.removeItem(key) + window.dispatchEvent(new CustomEvent(LOCAL_STORAGE_CHANGE_EVENT, { + detail: { key }, + })) + } + catch { + + } +} + +export const getLocalStorageBoolean = (key: string, fallback = false) => { + const value = getLocalStorageItem(key) + if (value === null) + return fallback + + return value === 'true' +} + +export const getLocalStorageNumber = (key: string, fallback: number) => { + const value = getLocalStorageItem(key) + if (!value) + return fallback + + const parsed = Number.parseFloat(value) + return Number.isNaN(parsed) ? fallback : parsed +} + +const subscribeLocalStorage = (key: string, onStoreChange: () => void) => { + if (!isClient) + return () => {} + + const handleChange = (event: Event) => { + if (event instanceof StorageEvent && event.key !== key) + return + if (event instanceof CustomEvent && event.detail?.key !== key) + return + + onStoreChange() + } + + window.addEventListener('storage', handleChange) + window.addEventListener(LOCAL_STORAGE_CHANGE_EVENT, handleChange) + + return () => { + window.removeEventListener('storage', handleChange) + window.removeEventListener(LOCAL_STORAGE_CHANGE_EVENT, handleChange) + } +} + +export const useLocalStorageItem = (key: string, fallback: string | null = null) => { + return useSyncExternalStore( + onStoreChange => subscribeLocalStorage(key, onStoreChange), + () => getLocalStorageItem(key, fallback), + () => fallback, + ) +} + +export const useLocalStorageBoolean = (key: string, fallback = false) => { + const value = useLocalStorageItem(key) + if (value === null) + return fallback + + return value === 'true' +} diff --git a/web/utils/setup-status.spec.ts b/web/utils/setup-status.spec.ts deleted file mode 100644 index be96b43eba..0000000000 --- a/web/utils/setup-status.spec.ts +++ /dev/null @@ -1,139 +0,0 @@ -import type { SetupStatusResponse } from '@/models/common' - -import { fetchSetupStatus } from '@/service/common' - -import { fetchSetupStatusWithCache } from './setup-status' - -vi.mock('@/service/common', () => ({ - fetchSetupStatus: vi.fn(), -})) - -const mockFetchSetupStatus = vi.mocked(fetchSetupStatus) - -describe('setup-status utilities', () => { - beforeEach(() => { - vi.clearAllMocks() - localStorage.clear() - }) - - describe('fetchSetupStatusWithCache', () => { - describe('when cache exists', () => { - it('should return cached finished status without API call', async () => { - localStorage.setItem('setup_status', 'finished') - - const result = await fetchSetupStatusWithCache() - - expect(result).toEqual({ step: 'finished' }) - expect(mockFetchSetupStatus).not.toHaveBeenCalled() - }) - - it('should not modify localStorage when returning cached value', async () => { - localStorage.setItem('setup_status', 'finished') - - await fetchSetupStatusWithCache() - - expect(localStorage.getItem('setup_status')).toBe('finished') - }) - }) - - describe('when cache does not exist', () => { - it('should call API and cache finished status', async () => { - const apiResponse: SetupStatusResponse = { step: 'finished' } - mockFetchSetupStatus.mockResolvedValue(apiResponse) - - const result = await fetchSetupStatusWithCache() - - expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) - expect(result).toEqual(apiResponse) - expect(localStorage.getItem('setup_status')).toBe('finished') - }) - - it('should call API and remove cache when not finished', async () => { - const apiResponse: SetupStatusResponse = { step: 'not_started' } - mockFetchSetupStatus.mockResolvedValue(apiResponse) - - const result = await fetchSetupStatusWithCache() - - expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) - expect(result).toEqual(apiResponse) - expect(localStorage.getItem('setup_status')).toBeNull() - }) - - it('should clear stale cache when API returns not_started', async () => { - localStorage.setItem('setup_status', 'some_invalid_value') - const apiResponse: SetupStatusResponse = { step: 'not_started' } - mockFetchSetupStatus.mockResolvedValue(apiResponse) - - const result = await fetchSetupStatusWithCache() - - expect(result).toEqual(apiResponse) - expect(localStorage.getItem('setup_status')).toBeNull() - }) - }) - - describe('cache edge cases', () => { - it('should call API when cache value is empty string', async () => { - localStorage.setItem('setup_status', '') - const apiResponse: SetupStatusResponse = { step: 'finished' } - mockFetchSetupStatus.mockResolvedValue(apiResponse) - - const result = await fetchSetupStatusWithCache() - - expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) - expect(result).toEqual(apiResponse) - }) - - it('should call API when cache value is not "finished"', async () => { - localStorage.setItem('setup_status', 'not_started') - const apiResponse: SetupStatusResponse = { step: 'finished' } - mockFetchSetupStatus.mockResolvedValue(apiResponse) - - const result = await fetchSetupStatusWithCache() - - expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) - expect(result).toEqual(apiResponse) - }) - - it('should call API when localStorage key does not exist', async () => { - const apiResponse: SetupStatusResponse = { step: 'finished' } - mockFetchSetupStatus.mockResolvedValue(apiResponse) - - const result = await fetchSetupStatusWithCache() - - expect(mockFetchSetupStatus).toHaveBeenCalledTimes(1) - expect(result).toEqual(apiResponse) - }) - }) - - describe('API response handling', () => { - it('should preserve setup_at from API response', async () => { - const setupDate = new Date('2024-01-01') - const apiResponse: SetupStatusResponse = { - step: 'finished', - setup_at: setupDate, - } - mockFetchSetupStatus.mockResolvedValue(apiResponse) - - const result = await fetchSetupStatusWithCache() - - expect(result).toEqual(apiResponse) - expect(result.setup_at).toEqual(setupDate) - }) - - it('should propagate API errors', async () => { - const apiError = new Error('Network error') - mockFetchSetupStatus.mockRejectedValue(apiError) - - await expect(fetchSetupStatusWithCache()).rejects.toThrow('Network error') - }) - - it('should not update cache when API call fails', async () => { - mockFetchSetupStatus.mockRejectedValue(new Error('API error')) - - await expect(fetchSetupStatusWithCache()).rejects.toThrow() - - expect(localStorage.getItem('setup_status')).toBeNull() - }) - }) - }) -}) diff --git a/web/utils/setup-status.ts b/web/utils/setup-status.ts deleted file mode 100644 index 7a2810bffd..0000000000 --- a/web/utils/setup-status.ts +++ /dev/null @@ -1,21 +0,0 @@ -import type { SetupStatusResponse } from '@/models/common' -import { fetchSetupStatus } from '@/service/common' - -const SETUP_STATUS_KEY = 'setup_status' - -const isSetupStatusCached = (): boolean => - localStorage.getItem(SETUP_STATUS_KEY) === 'finished' - -export const fetchSetupStatusWithCache = async (): Promise => { - if (isSetupStatusCached()) - return { step: 'finished' } - - const status = await fetchSetupStatus() - - if (status.step === 'finished') - localStorage.setItem(SETUP_STATUS_KEY, 'finished') - else - localStorage.removeItem(SETUP_STATUS_KEY) - - return status -}