Compare commits

..

5 Commits

Author SHA1 Message Date
fdb4008b9a fix: knowledge hit-testing render failed. (#36106) (#36110)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-13 15:48:40 +08:00
7e5c65175a fix(security): isolate Langfuse v3 SDK TracerProvider to prevent cros… (#36107) 2026-05-13 00:05:15 -07:00
09690fbb7d fix: remove unnecessary 'relative' class from DialogContent in install-bundle and publish-as-knowledge-pipeline-modal components 2026-05-12 19:02:13 +08:00
1144060a39 refactor: clean up DialogContent styles and enhance Publisher component functionality
- Removed unnecessary 'relative' class from DialogContent in InstallBundle and PublishAsKnowledgePipelineModal components for cleaner styling.
- Enhanced Publisher component to manage the state and functionality for publishing as a knowledge pipeline, including modal handling and API integration.
- Updated tests to reflect changes in the Publisher component and ensure proper functionality of the publish-as modal.
2026-05-12 19:01:20 +08:00
6a7ec862b1 fix: can not create empty knowledge 2026-05-12 18:08:56 +08:00
13 changed files with 439 additions and 127 deletions

View File

@ -39,11 +39,8 @@ class HitTestingPayload(BaseModel):
class DatasetsHitTestingBase:
@staticmethod
def _normalize_hit_testing_query(query: Any) -> str:
"""Return the user-visible query string from legacy and current response shapes."""
if isinstance(query, str):
return query
def _extract_hit_testing_query(query: Any) -> str:
"""Return the query string from the service response shape."""
if isinstance(query, dict):
content = query.get("content")
if isinstance(content, str):
@ -52,15 +49,15 @@ class DatasetsHitTestingBase:
raise ValueError("Invalid hit testing query response")
@staticmethod
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
"""Coerce nullable collection fields into lists before response validation."""
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
"""Ensure collection fields match the API schema before response validation."""
if not isinstance(records, list):
return []
raise ValueError("Invalid hit testing records response")
normalized_records: list[dict[str, Any]] = []
for record in records:
if not isinstance(record, dict):
continue
raise ValueError("Invalid hit testing record response")
normalized_record = dict(record)
segment = normalized_record.get("segment")
@ -118,8 +115,8 @@ class DatasetsHitTestingBase:
limit=10,
)
return {
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
marshal(response.get("records", []), hit_testing_record_fields)
),
}

View File

@ -13,6 +13,8 @@ from langfuse.api import (
TraceBody,
)
from langfuse.api.commons.types.usage import Usage
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
@ -52,13 +54,40 @@ class LangFuseDataTrace(BaseTraceInstance):
langfuse_config: LangfuseConfig,
):
super().__init__(langfuse_config)
# Isolated TracerProvider prevents the langfuse v3 SDK from attaching its
# SpanProcessor to the global OpenTelemetry TracerProvider, which would
# otherwise siphon every Flask/Celery/SQLAlchemy span in the process into
# this tenant's Langfuse project. See langfuse upgrade guide v2 -> v3.
self._tracer_provider = TracerProvider(
resource=Resource.create({"service.name": "dify-langfuse-app-trace"}),
)
self.langfuse_client = Langfuse(
public_key=langfuse_config.public_key,
secret_key=langfuse_config.secret_key,
host=langfuse_config.host,
tracer_provider=self._tracer_provider,
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def close(self) -> None:
"""Flush and shut down the isolated TracerProvider.
Called explicitly when the trace instance is evicted from the cache, or
implicitly via ``__del__`` on garbage collection. Idempotent.
"""
provider = getattr(self, "_tracer_provider", None)
if provider is None:
return
try:
provider.shutdown()
except Exception:
logger.debug("Failed to shut down Langfuse TracerProvider", exc_info=True)
finally:
self._tracer_provider = None
def __del__(self) -> None:
self.close()
@staticmethod
def _get_completion_start_time(
start_time: datetime | None, time_to_first_token: float | int | None

View File

@ -50,20 +50,94 @@ def trace_instance(langfuse_config, monkeypatch: pytest.MonkeyPatch):
def test_init(langfuse_config, monkeypatch: pytest.MonkeyPatch):
from opentelemetry.sdk.trace import TracerProvider
mock_langfuse = MagicMock()
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = LangFuseDataTrace(langfuse_config)
mock_langfuse.assert_called_once_with(
public_key=langfuse_config.public_key,
secret_key=langfuse_config.secret_key,
host=langfuse_config.host,
)
mock_langfuse.assert_called_once()
kwargs = mock_langfuse.call_args.kwargs
assert kwargs["public_key"] == langfuse_config.public_key
assert kwargs["secret_key"] == langfuse_config.secret_key
assert kwargs["host"] == langfuse_config.host
assert isinstance(kwargs["tracer_provider"], TracerProvider)
assert kwargs["tracer_provider"] is instance._tracer_provider
assert instance.file_base_url == "http://test.url"
def test_init_passes_isolated_tracer_provider_to_langfuse(
langfuse_config, monkeypatch: pytest.MonkeyPatch
):
"""Regression test for langfuse v3 SDK side effect.
Without an explicit ``tracer_provider=`` kwarg, the Langfuse v3 SDK
attaches a ``LangfuseSpanProcessor`` to the *global* OpenTelemetry
TracerProvider — siphoning every Flask / Celery / SQLAlchemy span in the
process into the tenant's Langfuse project. See langfuse upgrade-path
docs (v2 -> v3) and GitHub discussion #9136.
The fix is to construct an isolated ``TracerProvider`` and pass it via
``tracer_provider=`` so the SDK never touches the global one.
"""
from opentelemetry import trace as otel_trace_api
from opentelemetry.sdk.trace import TracerProvider
captured: dict[str, object] = {}
def fake_langfuse(**kwargs):
captured.update(kwargs)
return MagicMock()
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", fake_langfuse)
instance = LangFuseDataTrace(langfuse_config)
# 1. tracer_provider kwarg must be supplied (drives the no-pollution branch
# in langfuse.LangfuseResourceManager._init_tracer_provider).
assert "tracer_provider" in captured, (
"Langfuse() must receive an explicit tracer_provider=; without it the "
"v3 SDK attaches its SpanProcessor to the global OTEL TracerProvider."
)
passed_provider = captured["tracer_provider"]
assert isinstance(passed_provider, TracerProvider)
assert passed_provider is instance._tracer_provider
# 2. The instance's provider must not be the global one.
global_provider = otel_trace_api.get_tracer_provider()
assert passed_provider is not global_provider
def test_close_shuts_down_tracer_provider(langfuse_config, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: MagicMock())
instance = LangFuseDataTrace(langfuse_config)
provider = instance._tracer_provider
provider_shutdown = MagicMock()
monkeypatch.setattr(provider, "shutdown", provider_shutdown)
instance.close()
provider_shutdown.assert_called_once()
assert instance._tracer_provider is None
def test_close_is_idempotent(langfuse_config, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: MagicMock())
instance = LangFuseDataTrace(langfuse_config)
provider_shutdown = MagicMock()
monkeypatch.setattr(instance._tracer_provider, "shutdown", provider_shutdown)
instance.close()
instance.close()
provider_shutdown.assert_called_once()
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
methods = [
"workflow_trace",

View File

@ -3,6 +3,8 @@ import logging
import time
from typing import Any, TypedDict, cast
from sqlalchemy import select
from core.app.app_config.entities import ModelConfig
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
@ -13,6 +15,7 @@ from extensions.ext_database import db
from graphon.model_runtime.entities import LLMMode
from models import Account
from models.dataset import Dataset, DatasetQuery
from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)
@ -41,6 +44,59 @@ class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
class HitTestingService:
@staticmethod
def _dump_dataset_document(document: DatasetDocument) -> dict[str, Any]:
return {
"id": document.id,
"data_source_type": document.data_source_type,
"name": document.name,
"doc_type": document.doc_type,
"doc_metadata": document.doc_metadata,
}
@classmethod
def _dump_retrieval_records(cls, records: list[Any]) -> list[dict[str, Any]]:
dumped_records = [record.model_dump() for record in records]
document_ids = {
segment.get("document_id")
for record in dumped_records
if isinstance(record, dict)
for segment in [record.get("segment")]
if isinstance(segment, dict) and segment.get("document_id")
}
if not document_ids:
return dumped_records
documents = {
document.id: cls._dump_dataset_document(document)
for document in db.session.scalars(
select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
).all()
}
records_with_documents: list[dict[str, Any]] = []
missing_document_ids: set[str] = set()
for record in dumped_records:
segment = record.get("segment")
if not isinstance(segment, dict):
records_with_documents.append(record)
continue
document_id = segment.get("document_id")
if document_id in documents:
segment["document"] = documents[document_id]
records_with_documents.append(record)
elif document_id:
missing_document_ids.add(document_id)
if missing_document_ids:
logger.warning(
"Skipping hit-testing records with missing documents, document_ids=%s",
sorted(missing_document_ids),
)
return records_with_documents
@classmethod
def retrieve(
cls,
@ -174,7 +230,7 @@ class HitTestingService:
"query": {
"content": query,
},
"records": [record.model_dump() for record in records],
"records": cls._dump_retrieval_records(records),
}
@classmethod

View File

@ -120,7 +120,7 @@ class TestParseArgs:
class TestPerformHitTesting:
def test_success(self, dataset):
response = {
"query": "hello",
"query": {"content": "hello"},
"records": [],
}
@ -134,7 +134,7 @@ class TestPerformHitTesting:
assert result["query"] == "hello"
assert result["records"] == []
def test_success_normalizes_legacy_query_and_nullable_list_fields(self, dataset):
def test_success_prepares_nullable_list_fields(self, dataset):
response = {
"query": {"content": "hello"},
"records": [
@ -170,6 +170,18 @@ class TestPerformHitTesting:
}
]
def test_invalid_query_response_raises_value_error(self):
with pytest.raises(ValueError, match="Invalid hit testing query response"):
DatasetsHitTestingBase._extract_hit_testing_query("hello")
def test_invalid_records_response_raises_value_error(self):
with pytest.raises(ValueError, match="Invalid hit testing records response"):
DatasetsHitTestingBase._prepare_hit_testing_records({"records": []})
def test_invalid_record_response_raises_value_error(self):
with pytest.raises(ValueError, match="Invalid hit testing record response"):
DatasetsHitTestingBase._prepare_hit_testing_records(["record"])
def test_index_not_initialized(self, dataset):
with patch.object(
HitTestingService,

View File

@ -103,7 +103,7 @@ class TestHitTestingApiPost:
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_hit_svc.retrieve.return_value = {"query": "test query", "records": []}
mock_hit_svc.retrieve.return_value = {"query": {"content": "test query"}, "records": []}
mock_hit_svc.hit_testing_args_check.return_value = None
mock_marshal.return_value = []
@ -149,7 +149,7 @@ class TestHitTestingApiPost:
"score_threshold": 0.8,
}
mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []}
mock_hit_svc.retrieve.return_value = {"query": {"content": "complex query"}, "records": []}
mock_hit_svc.hit_testing_args_check.return_value = None
mock_marshal.return_value = []
@ -194,7 +194,7 @@ class TestHitTestingApiPost:
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
mock_hit_svc.retrieve.return_value = {"query": "filtered query", "records": []}
mock_hit_svc.retrieve.return_value = {"query": {"content": "filtered query"}, "records": []}
mock_hit_svc.hit_testing_args_check.return_value = None
mock_marshal.return_value = []
@ -232,7 +232,7 @@ class TestHitTestingApiPost:
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_normalizes_legacy_query_and_nullable_list_fields(
def test_post_prepares_nullable_list_fields(
self,
mock_current_user,
mock_dataset_svc,
@ -241,7 +241,7 @@ class TestHitTestingApiPost:
mock_ns,
app,
):
"""Test service API normalizes legacy query shape and nullable list fields."""
"""Test service API prepares nullable list fields from marshalled records."""
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())

View File

@ -0,0 +1,88 @@
from unittest.mock import Mock, patch
from services.hit_testing_service import HitTestingService
def _retrieval_record(payload: dict):
record = Mock()
record.model_dump.return_value = payload
return record
def _dataset_document(
document_id: str = "document-1",
name: str = "guide.md",
data_source_type: str = "upload_file",
doc_type: str | None = None,
doc_metadata: dict | None = None,
):
document = Mock()
document.id = document_id
document.name = name
document.data_source_type = data_source_type
document.doc_type = doc_type
document.doc_metadata = doc_metadata
return document
class TestHitTestingServiceDumpRecords:
def test_dump_dataset_document_returns_frontend_required_fields(self):
document = _dataset_document(doc_metadata={"source": "manual"})
assert HitTestingService._dump_dataset_document(document) == {
"id": "document-1",
"data_source_type": "upload_file",
"name": "guide.md",
"doc_type": None,
"doc_metadata": {"source": "manual"},
}
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self):
record = _retrieval_record({"segment": None, "score": 0.95})
assert HitTestingService._dump_retrieval_records([record]) == [{"segment": None, "score": 0.95}]
def test_dump_retrieval_records_injects_documents_and_keeps_non_segment_records(self):
record_without_segment = _retrieval_record({"segment": None, "score": 0.95})
record_with_document = _retrieval_record(
{
"segment": {
"id": "segment-1",
"document_id": "document-1",
},
"score": 0.9,
}
)
scalars_result = Mock()
scalars_result.all.return_value = [_dataset_document()]
with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result):
result = HitTestingService._dump_retrieval_records([record_without_segment, record_with_document])
assert result[0] == {"segment": None, "score": 0.95}
assert result[1]["segment"]["document"] == {
"id": "document-1",
"data_source_type": "upload_file",
"name": "guide.md",
"doc_type": None,
"doc_metadata": None,
}
def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog):
record = _retrieval_record(
{
"segment": {
"id": "segment-1",
"document_id": "missing-document",
},
"score": 0.95,
}
)
scalars_result = Mock()
scalars_result.all.return_value = []
with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result):
result = HitTestingService._dump_retrieval_records([record])
assert result == []
assert "Skipping hit-testing records with missing documents" in caplog.text

View File

@ -57,7 +57,7 @@ const InstallBundle: FC<Props> = ({
foldAnimInto()
}}
>
<DialogContent className={cn('relative w-full max-w-[480px] overflow-hidden! text-left align-middle', cn(modalClassName, 'shadows-shadow-xl flex min-w-[560px] flex-col items-start rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg p-0'))}>
<DialogContent className={cn('w-full max-w-[480px] overflow-hidden! text-left align-middle', cn(modalClassName, 'shadows-shadow-xl flex min-w-[560px] flex-col items-start rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg p-0'))}>
<DialogCloseButton />
<div className="flex items-start gap-2 self-stretch pt-6 pr-14 pb-3 pl-6">

View File

@ -77,7 +77,7 @@ const PublishAsKnowledgePipelineModal = ({
return (
<>
<Dialog open>
<DialogContent className="relative w-full max-w-[480px]! overflow-hidden! border-none p-0! text-left align-middle">
<DialogContent className="w-full max-w-[480px]! overflow-hidden! border-none p-0! text-left align-middle">
<div className="relative flex items-center p-6 pr-14 pb-3 title-2xl-semi-bold text-text-primary">
{t('common.publishAs', { ns: 'pipeline' })}

View File

@ -480,7 +480,9 @@ describe('publisher', () => {
it('should show publish as knowledge pipeline modal when permitted', async () => {
mockPublishedAt.mockReturnValue(1700000000)
mockIsAllowPublishAsCustomKnowledgePipelineTemplate.mockReturnValue(true)
renderWithQueryClient(<Popup />)
renderWithQueryClient(<Publisher />)
fireEvent.click(screen.getByText('workflow.common.publish'))
const publishAsButton = screen.getAllByRole('button').find(btn =>
btn.textContent?.includes('pipeline.common.publishAs'),
@ -495,7 +497,9 @@ describe('publisher', () => {
it('should close publish as knowledge pipeline modal when cancel is clicked', async () => {
mockPublishedAt.mockReturnValue(1700000000)
mockIsAllowPublishAsCustomKnowledgePipelineTemplate.mockReturnValue(true)
renderWithQueryClient(<Popup />)
renderWithQueryClient(<Publisher />)
fireEvent.click(screen.getByText('workflow.common.publish'))
const publishAsButton = screen.getAllByRole('button').find(btn =>
btn.textContent?.includes('pipeline.common.publishAs'),
@ -516,7 +520,9 @@ describe('publisher', () => {
it('should call publishAsCustomizedPipeline when confirm is clicked in modal', async () => {
mockPublishedAt.mockReturnValue(1700000000)
mockPublishAsCustomizedPipeline.mockResolvedValue({})
renderWithQueryClient(<Popup />)
renderWithQueryClient(<Publisher />)
fireEvent.click(screen.getByText('workflow.common.publish'))
const publishAsButton = screen.getAllByRole('button').find(btn =>
btn.textContent?.includes('pipeline.common.publishAs'),
@ -538,6 +544,35 @@ describe('publisher', () => {
})
})
})
it('should publish as template with empty pipeline id fallback', async () => {
mockPublishedAt.mockReturnValue(1700000000)
mockPipelineId.mockReturnValue(undefined as unknown as string)
mockPublishAsCustomizedPipeline.mockResolvedValue({})
renderWithQueryClient(<Publisher />)
fireEvent.click(screen.getByText('workflow.common.publish'))
const publishAsButton = screen.getAllByRole('button').find(btn =>
btn.textContent?.includes('pipeline.common.publishAs'),
)
fireEvent.click(publishAsButton!)
await waitFor(() => {
expect(screen.getByTestId('publish-as-knowledge-pipeline-modal')).toBeInTheDocument()
})
fireEvent.click(screen.getByTestId('modal-confirm'))
await waitFor(() => {
expect(mockPublishAsCustomizedPipeline).toHaveBeenCalledWith({
pipelineId: '',
name: 'Test Pipeline',
icon_info: { type: 'emoji', emoji: '📚', background: '#fff' },
description: 'Test description',
})
})
})
})
describe('API Calls and Async Operations', () => {
@ -607,7 +642,9 @@ describe('publisher', () => {
it('should show success notification for publish as template', async () => {
mockPublishedAt.mockReturnValue(1700000000)
mockPublishAsCustomizedPipeline.mockResolvedValue({})
renderWithQueryClient(<Popup />)
renderWithQueryClient(<Publisher />)
fireEvent.click(screen.getByText('workflow.common.publish'))
const publishAsButton = screen.getAllByRole('button').find(btn =>
btn.textContent?.includes('pipeline.common.publishAs'),
@ -633,7 +670,9 @@ describe('publisher', () => {
it('should invalidate customized template list after publish as template', async () => {
mockPublishedAt.mockReturnValue(1700000000)
mockPublishAsCustomizedPipeline.mockResolvedValue({})
renderWithQueryClient(<Popup />)
renderWithQueryClient(<Publisher />)
fireEvent.click(screen.getByText('workflow.common.publish'))
const publishAsButton = screen.getAllByRole('button').find(btn =>
btn.textContent?.includes('pipeline.common.publishAs'),
@ -686,7 +725,9 @@ describe('publisher', () => {
it('should show error notification when publish as template fails', async () => {
mockPublishedAt.mockReturnValue(1700000000)
mockPublishAsCustomizedPipeline.mockRejectedValue(new Error('Template publish failed'))
renderWithQueryClient(<Popup />)
renderWithQueryClient(<Publisher />)
fireEvent.click(screen.getByText('workflow.common.publish'))
const publishAsButton = screen.getAllByRole('button').find(btn =>
btn.textContent?.includes('pipeline.common.publishAs'),
@ -710,7 +751,9 @@ describe('publisher', () => {
it('should close modal after publish as template error', async () => {
mockPublishedAt.mockReturnValue(1700000000)
mockPublishAsCustomizedPipeline.mockRejectedValue(new Error('Template publish failed'))
renderWithQueryClient(<Popup />)
renderWithQueryClient(<Publisher />)
fireEvent.click(screen.getByText('workflow.common.publish'))
const publishAsButton = screen.getAllByRole('button').find(btn =>
btn.textContent?.includes('pipeline.common.publishAs'),
@ -1051,7 +1094,9 @@ describe('publisher', () => {
it('should complete full publish as template flow', async () => {
mockPublishedAt.mockReturnValue(1700000000)
mockPublishAsCustomizedPipeline.mockResolvedValue({})
renderWithQueryClient(<Popup />)
renderWithQueryClient(<Publisher />)
fireEvent.click(screen.getByText('workflow.common.publish'))
const publishAsButton = screen.getAllByRole('button').find(btn =>
btn.textContent?.includes('pipeline.common.publishAs'),

View File

@ -327,11 +327,18 @@ describe('Popup', () => {
it('should request closing the outer popover before opening publish-as modal', () => {
const onRequestClose = vi.fn()
render(<Popup onRequestClose={onRequestClose} />)
const onShowPublishAsKnowledgePipelineModal = vi.fn()
render(
<Popup
onRequestClose={onRequestClose}
onShowPublishAsKnowledgePipelineModal={onShowPublishAsKnowledgePipelineModal}
/>,
)
fireEvent.click(screen.getByText('pipeline.common.publishAs'))
expect(onRequestClose).toHaveBeenCalledTimes(1)
expect(onShowPublishAsKnowledgePipelineModal).toHaveBeenCalledTimes(1)
})
})
@ -352,27 +359,6 @@ describe('Popup', () => {
})
})
describe('Publish params', () => {
it('should publish as template with empty pipeline id fallback', async () => {
mockPipelineId = undefined
mockUseBoolean
.mockImplementationOnce((initial: boolean) => [initial, { setFalse: vi.fn(), setTrue: vi.fn() }])
.mockImplementationOnce((initial: boolean) => [initial, { setFalse: vi.fn(), setTrue: vi.fn() }])
.mockImplementationOnce(() => [true, { setFalse: vi.fn(), setTrue: vi.fn() }])
.mockImplementationOnce((initial: boolean) => [initial, { setFalse: vi.fn(), setTrue: vi.fn() }])
render(<Popup />)
fireEvent.click(screen.getByTestId('publish-as-confirm'))
expect(mockPublishAsCustomizedPipeline).toHaveBeenCalledWith({
pipelineId: '',
name: 'My Pipeline',
icon_info: { icon_type: 'emoji' },
description: 'desc',
})
})
})
describe('Time formatting', () => {
it('should format published time', () => {
render(<Popup />)

View File

@ -1,6 +1,8 @@
import type { IconInfo } from '@/models/datasets'
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Popover, PopoverContent, PopoverTrigger } from '@langgenius/dify-ui/popover'
import { toast } from '@langgenius/dify-ui/toast'
import { RiArrowDownSLine } from '@remixicon/react'
import { useBoolean } from 'ahooks'
import {
@ -10,6 +12,11 @@ import {
} from 'react'
import { useTranslation } from 'react-i18next'
import { useNodesSyncDraft } from '@/app/components/workflow/hooks'
import { useStore } from '@/app/components/workflow/store'
import { useDocLink } from '@/context/i18n'
import Link from '@/next/link'
import { useInvalidCustomizedTemplateList, usePublishAsCustomizedPipeline } from '@/service/use-pipeline'
import PublishAsKnowledgePipelineModal from '../../publish-as-knowledge-pipeline-modal'
import Popup from './popup'
const Publisher = () => {
@ -17,6 +24,12 @@ const Publisher = () => {
const [open, setOpen] = useState(false)
const [confirmVisible, { setFalse: hideConfirm, setTrue: showConfirm }] = useBoolean(false)
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
const docLink = useDocLink()
const pipelineId = useStore(s => s.pipelineId)
const { mutateAsync: publishAsCustomizedPipeline } = usePublishAsCustomizedPipeline()
const invalidCustomizedTemplateList = useInvalidCustomizedTemplateList()
const [showPublishAsKnowledgePipelineModal, setShowPublishAsKnowledgePipelineModal] = useState(false)
const [isPublishingAsCustomizedPipeline, setIsPublishingAsCustomizedPipeline] = useState(false)
const handleOpenChange = useCallback((newOpen: boolean) => {
if (!newOpen && confirmVisible)
@ -28,38 +41,86 @@ const Publisher = () => {
const closePopover = useCallback(() => {
setOpen(false)
}, [])
const openPublishAsKnowledgePipelineModal = useCallback(() => {
setShowPublishAsKnowledgePipelineModal(true)
}, [])
const hidePublishAsKnowledgePipelineModal = useCallback(() => {
setShowPublishAsKnowledgePipelineModal(false)
}, [])
const handlePublishAsKnowledgePipeline = useCallback(async (name: string, icon: IconInfo, description?: string) => {
try {
setIsPublishingAsCustomizedPipeline(true)
await publishAsCustomizedPipeline({
pipelineId: pipelineId || '',
name,
icon_info: icon,
description,
})
toast.success(t('publishTemplate.success.message', { ns: 'datasetPipeline' }), {
description: (
<div className="flex flex-col gap-y-1">
<span className="system-xs-regular text-text-secondary">
{t('publishTemplate.success.tip', { ns: 'datasetPipeline' })}
</span>
<Link href={docLink()} target="_blank" className="inline-block system-xs-medium-uppercase text-text-accent">
{t('publishTemplate.success.learnMore', { ns: 'datasetPipeline' })}
</Link>
</div>
),
})
invalidCustomizedTemplateList()
}
catch {
toast.error(t('publishTemplate.error.message', { ns: 'datasetPipeline' }))
}
finally {
setIsPublishingAsCustomizedPipeline(false)
hidePublishAsKnowledgePipelineModal()
}
}, [docLink, hidePublishAsKnowledgePipelineModal, invalidCustomizedTemplateList, pipelineId, publishAsCustomizedPipeline, t])
return (
<Popover
open={open}
onOpenChange={handleOpenChange}
>
<PopoverTrigger
nativeButton
render={(
<Button
className="px-2"
variant="primary"
>
<span className="pl-1">{t('common.publish', { ns: 'workflow' })}</span>
<RiArrowDownSLine className="h-4 w-4" />
</Button>
)}
/>
<PopoverContent
placement="bottom-end"
sideOffset={4}
alignOffset={40}
popupClassName={cn('border-none bg-transparent shadow-none', confirmVisible && 'hidden')}
<>
<Popover
open={open}
onOpenChange={handleOpenChange}
>
<Popup
onRequestClose={closePopover}
confirmVisible={confirmVisible}
onShowConfirm={showConfirm}
onHideConfirm={hideConfirm}
<PopoverTrigger
nativeButton
render={(
<Button
className="px-2"
variant="primary"
>
<span className="pl-1">{t('common.publish', { ns: 'workflow' })}</span>
<RiArrowDownSLine className="h-4 w-4" />
</Button>
)}
/>
</PopoverContent>
</Popover>
<PopoverContent
placement="bottom-end"
sideOffset={4}
alignOffset={40}
popupClassName={cn('border-none bg-transparent shadow-none', confirmVisible && 'hidden')}
>
<Popup
onRequestClose={closePopover}
confirmVisible={confirmVisible}
onShowConfirm={showConfirm}
onHideConfirm={hideConfirm}
isPublishingAsCustomizedPipeline={isPublishingAsCustomizedPipeline}
onShowPublishAsKnowledgePipelineModal={openPublishAsKnowledgePipelineModal}
/>
</PopoverContent>
</Popover>
{showPublishAsKnowledgePipelineModal && (
<PublishAsKnowledgePipelineModal
confirmDisabled={isPublishingAsCustomizedPipeline}
onConfirm={handlePublishAsKnowledgePipeline}
onCancel={hidePublishAsKnowledgePipelineModal}
/>
)}
</>
)
}

View File

@ -1,4 +1,3 @@
import type { IconInfo } from '@/models/datasets'
import type { PublishWorkflowParams } from '@/types/workflow'
import {
AlertDialog,
@ -25,7 +24,6 @@ import ShortcutsName from '@/app/components/workflow/shortcuts-name'
import { useStore, useWorkflowStore } from '@/app/components/workflow/store'
import { getKeyboardKeyCodeBySystem } from '@/app/components/workflow/utils'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDocLink } from '@/context/i18n'
import { useModalContextSelector } from '@/context/modal-context'
import { useProviderContextSelector } from '@/context/provider-context'
import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url'
@ -34,9 +32,8 @@ import Link from '@/next/link'
import { useParams, useRouter } from '@/next/navigation'
import { useInvalidDatasetList } from '@/service/knowledge/use-dataset'
import { useInvalid } from '@/service/use-base'
import { publishedPipelineInfoQueryKeyPrefix, useInvalidCustomizedTemplateList, usePublishAsCustomizedPipeline } from '@/service/use-pipeline'
import { publishedPipelineInfoQueryKeyPrefix } from '@/service/use-pipeline'
import { usePublishWorkflow } from '@/service/use-workflow'
import PublishAsKnowledgePipelineModal from '../../publish-as-knowledge-pipeline-modal'
const PUBLISH_SHORTCUT = ['ctrl', '⇧', 'P']
type PopupProps = {
@ -44,6 +41,8 @@ type PopupProps = {
confirmVisible?: boolean
onShowConfirm?: () => void
onHideConfirm?: () => void
isPublishingAsCustomizedPipeline?: boolean
onShowPublishAsKnowledgePipelineModal?: () => void
}
const Popup = ({
@ -51,11 +50,12 @@ const Popup = ({
confirmVisible: controlledConfirmVisible,
onShowConfirm,
onHideConfirm,
isPublishingAsCustomizedPipeline = false,
onShowPublishAsKnowledgePipelineModal,
}: PopupProps) => {
const { t } = useTranslation()
const { datasetId } = useParams()
const { push } = useRouter()
const docLink = useDocLink()
const publishedAt = useStore(s => s.publishedAt)
const draftUpdatedAt = useStore(s => s.draftUpdatedAt)
const pipelineId = useStore(s => s.pipelineId)
@ -73,9 +73,6 @@ const Popup = ({
const showConfirm = onShowConfirm ?? showLocalConfirm
const hideConfirm = onHideConfirm ?? hideLocalConfirm
const [publishing, { setFalse: hidePublishing, setTrue: showPublishing }] = useBoolean(false)
const { mutateAsync: publishAsCustomizedPipeline } = usePublishAsCustomizedPipeline()
const [showPublishAsKnowledgePipelineModal, { setFalse: hidePublishAsKnowledgePipelineModal, setTrue: setShowPublishAsKnowledgePipelineModal }] = useBoolean(false)
const [isPublishingAsCustomizedPipeline, { setFalse: hidePublishingAsCustomizedPipeline, setTrue: showPublishingAsCustomizedPipeline }] = useBoolean(false)
const invalidPublishedPipelineInfo = useInvalid([...publishedPipelineInfoQueryKeyPrefix, pipelineId])
const invalidDatasetList = useInvalidDatasetList()
const handleHideConfirm = useCallback(() => {
@ -145,47 +142,15 @@ const Popup = ({
const goToAddDocuments = useCallback(() => {
push(`/datasets/${datasetId}/documents/create-from-pipeline`)
}, [datasetId, push])
const invalidCustomizedTemplateList = useInvalidCustomizedTemplateList()
const handlePublishAsKnowledgePipeline = useCallback(async (name: string, icon: IconInfo, description?: string) => {
try {
showPublishingAsCustomizedPipeline()
await publishAsCustomizedPipeline({
pipelineId: pipelineId || '',
name,
icon_info: icon,
description,
})
toast.success(t('publishTemplate.success.message', { ns: 'datasetPipeline' }), {
description: (
<div className="flex flex-col gap-y-1">
<span className="system-xs-regular text-text-secondary">
{t('publishTemplate.success.tip', { ns: 'datasetPipeline' })}
</span>
<Link href={docLink()} target="_blank" className="inline-block system-xs-medium-uppercase text-text-accent">
{t('publishTemplate.success.learnMore', { ns: 'datasetPipeline' })}
</Link>
</div>
),
})
invalidCustomizedTemplateList()
}
catch {
toast.error(t('publishTemplate.error.message', { ns: 'datasetPipeline' }))
}
finally {
hidePublishingAsCustomizedPipeline()
hidePublishAsKnowledgePipelineModal()
}
}, [showPublishingAsCustomizedPipeline, publishAsCustomizedPipeline, pipelineId, t, invalidCustomizedTemplateList, hidePublishingAsCustomizedPipeline, hidePublishAsKnowledgePipelineModal, docLink])
const handleClickPublishAsKnowledgePipeline = useCallback(() => {
onRequestClose?.()
if (!isAllowPublishAsCustomKnowledgePipelineTemplate) {
setShowPricingModal()
}
else {
setShowPublishAsKnowledgePipelineModal()
onShowPublishAsKnowledgePipelineModal?.()
}
}, [isAllowPublishAsCustomKnowledgePipelineTemplate, onRequestClose, setShowPublishAsKnowledgePipelineModal, setShowPricingModal])
}, [isAllowPublishAsCustomKnowledgePipelineTemplate, onRequestClose, onShowPublishAsKnowledgePipelineModal, setShowPricingModal])
return (
<div className={cn('rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xl shadow-shadow-shadow-5', isAllowPublishAsCustomKnowledgePipelineTemplate ? 'w-[360px]' : 'w-[400px]')}>
<div className="p-4 pt-3">
@ -279,7 +244,6 @@ const Popup = ({
</AlertDialogActions>
</AlertDialogContent>
</AlertDialog>
{showPublishAsKnowledgePipelineModal && (<PublishAsKnowledgePipelineModal confirmDisabled={isPublishingAsCustomizedPipeline} onConfirm={handlePublishAsKnowledgePipeline} onCancel={hidePublishAsKnowledgePipelineModal} />)}
</div>
)
}