Compare commits


17 Commits

Author SHA1 Message Date
753dc8752d feat: support relative mode for message clean command 2026-03-02 14:45:52 +08:00
aca3d1900e chore(deps-dev): update types-aiofiles requirement from ~=24.1.0 to ~=25.1.0 in /api (#32803)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-02 13:59:53 +09:00
ddc47c2f39 chore(deps): update pyjwt requirement from ~=2.10.1 to ~=2.11.0 in /api (#32804)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-02 13:59:06 +09:00
2dc9bc00d6 refactor(ci): use diff -u for pyrefly diff output (#32813)
Co-authored-by: Stable Genius <259448942+stablegenius49@users.noreply.github.com>
2026-03-02 13:58:13 +09:00
335b500aea test: add unit tests for base components (#32818)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-02 11:40:43 +08:00
8cc775d9f2 fix: optimize workflow_run iter query (#32815) 2026-03-02 11:01:11 +08:00
yyh 1a33903887 feat(web): add root isolation layer for portal stacking context (#32807) 2026-03-02 10:59:56 +08:00
00dbaef04f fix: use declared_attr.directive for WorkflowNodeExecutionModel.__table_args__ (#32656)
Signed-off-by: edvatar <88481784+toroleapinc@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-03-02 10:15:06 +08:00
248202c220 fix: remove references to non-existent Document attributes in test (#32654)
Signed-off-by: edvatar <88481784+toroleapinc@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-03-02 10:14:56 +08:00
691c9911c7 fix(ci): make pyrefly diff comments focus on diagnostics (#32778) 2026-03-02 10:11:23 +08:00
baeea77c5b fix: typo in WebScraper plugin description: "Scrapper" → "Scraper" (#32790)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-03-02 10:10:28 +08:00
9da98e6c6c fix: fix import error (#32800) 2026-03-02 08:59:53 +08:00
a01de98721 refactor(workflow): decouple start node external dependencies (#32793) 2026-03-02 02:01:41 +08:00
17c1538e03 refactor(workflow): move PromptMessageMemory to model_runtime.memory (#32796)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-02 01:58:02 +08:00
69b3e94630 refactor: inject workflow node memory via protocol (#32784) 2026-03-02 01:55:49 +08:00
ef2b5d6107 refactor(api): move llm quota deduction to app graph layer (#32786) 2026-03-01 23:25:36 +08:00
fa4b8910c8 chore: support code-inspector for vinext (#32788) 2026-03-01 20:27:57 +08:00
497 changed files with 2049 additions and 1266 deletions

View File

@@ -29,20 +29,26 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Prepare diagnostics extractor
run: |
git show ${{ github.event.pull_request.head.sha }}:api/libs/pyrefly_diagnostics.py > /tmp/pyrefly_diagnostics.py
- name: Run pyrefly on PR branch
run: |
uv run --directory api pyrefly check > /tmp/pyrefly_pr.txt 2>&1 || true
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_pr.txt || true
- name: Checkout base branch
run: git checkout ${{ github.base_ref }}
- name: Run pyrefly on base branch
run: |
uv run --directory api pyrefly check > /tmp/pyrefly_base.txt 2>&1 || true
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_base.txt || true
- name: Compute diff
run: |
diff /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
- name: Save PR number
run: |

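The diagnostics extractor referenced above (api/libs/pyrefly_diagnostics.py) is not shown in this compare. As a rough, hypothetical sketch only (the real script's matching rules and output format are assumptions), such a filter could read pyrefly's combined stdout/stderr from stdin and keep just the diagnostic lines, so the later diff -u step compares diagnostics rather than progress and timing noise:

# Hypothetical stand-in for api/libs/pyrefly_diagnostics.py; the real
# extractor's behavior is an assumption, not shown in this compare.
import re
import sys

# Assumed diagnostic shape: a severity keyword, then "path:line:col: message".
DIAGNOSTIC = re.compile(r"^\s*(ERROR|WARN)\b")

def main() -> None:
    kept = [line.rstrip("\n") for line in sys.stdin if DIAGNOSTIC.match(line)]
    # Sorting keeps base/PR outputs stable so diff -u shows only real changes.
    for line in sorted(kept):
        print(line)

if __name__ == "__main__":
    main()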
View File

@@ -29,6 +29,8 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
@@ -52,7 +54,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
# TODO(QuantumGhost): use DI to avoid depending on global DB.
@@ -91,7 +92,7 @@ forbidden_modules =
core.moderation
core.ops
core.plugin
core.model_runtime.prompt
core.prompt
core.provider_manager
core.rag
core.repositories
@@ -107,14 +108,11 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> models.model
core.workflow.nodes.llm.llm_utils -> models.provider
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
@@ -127,16 +125,14 @@ ignore_imports =
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.advanced_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.simple_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.simple_prompt_transform
core.workflow.nodes.start.entities -> core.app.app_config.entities
core.workflow.nodes.start.start_node -> core.app.app_config.entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
core.workflow.workflow_entry -> core.app.apps.exc
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
@@ -148,16 +144,15 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.model_runtime.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.prompt.utils.prompt_message_util
core.workflow.nodes.question_classifier.entities -> core.model_runtime.prompt.entities.advanced_prompt_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_runtime.prompt.utils.prompt_message_util
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
@@ -172,7 +167,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@@ -180,7 +174,7 @@ ignore_imports =
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums
core.workflow.nodes.llm.llm_utils -> models.provider_ids
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.llm.node -> models.model
core.workflow.workflow_entry -> models.enums
core.workflow.nodes.agent.agent_node -> services
@@ -190,12 +184,7 @@ ignore_imports =
name = Model Runtime Internal Imports
type = forbidden
source_modules =
core.model_runtime.callbacks
core.model_runtime.entities
core.model_runtime.errors
core.model_runtime.model_providers
core.model_runtime.schema_validators
core.model_runtime.utils
core.model_runtime
forbidden_modules =
configs
controllers
@@ -225,7 +214,7 @@ forbidden_modules =
core.moderation
core.ops
core.plugin
core.model_runtime.prompt
core.prompt
core.provider_manager
core.rag
core.repositories

View File

@@ -2598,15 +2598,29 @@ def migrate_oss(
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
required=False,
default=None,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
required=False,
default=None,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--from-days-ago",
type=int,
default=None,
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
)
@click.option(
"--before-days",
type=int,
default=None,
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
@@ -2618,8 +2632,10 @@ def migrate_oss(
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime,
end_before: datetime.datetime,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
from_days_ago: int | None,
before_days: int | None,
dry_run: bool,
):
"""
@@ -2630,18 +2646,64 @@ def clean_expired_messages(
start_at = time.perf_counter()
try:
abs_mode = start_from is not None and end_before is not None
rel_mode = before_days is not None
if abs_mode and rel_mode:
raise click.UsageError(
"Options are mutually exclusive: use either (--start-from,--end-before) "
"or (--from-days-ago,--before-days)."
)
if from_days_ago is not None and before_days is None:
raise click.UsageError("--from-days-ago must be used together with --before-days.")
if (start_from is None) ^ (end_before is None):
raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
if not abs_mode and not rel_mode:
raise click.UsageError(
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
)
if rel_mode:
if before_days < 0:
raise click.UsageError("--before-days must be >= 0.")
if from_days_ago is not None:
if from_days_ago < 0:
raise click.UsageError("--from-days-ago must be >= 0.")
if from_days_ago <= before_days:
raise click.UsageError("--from-days-ago must be greater than --before-days.")
# Create policy based on billing configuration
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
# Create and run the cleanup service
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
if abs_mode:
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
elif from_days_ago is None:
service = MessagesCleanService.from_days(
policy=policy,
days=before_days,
batch_size=batch_size,
dry_run=dry_run,
)
else:
now = datetime.datetime.now()
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=now - datetime.timedelta(days=from_days_ago),
end_before=now - datetime.timedelta(days=before_days),
batch_size=batch_size,
dry_run=dry_run,
)
stats = service.run()
end_at = time.perf_counter()

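For the new relative mode above, the selected window is [now - from_days_ago days, now - before_days days); a minimal self-contained sketch of that computation, useful for sanity-checking the bound validation in the command:

import datetime

def relative_window(from_days_ago: int, before_days: int) -> tuple[datetime.datetime, datetime.datetime]:
    # Mirrors the command's validation: the lower bound must lie further
    # in the past than the upper bound.
    if from_days_ago <= before_days:
        raise ValueError("--from-days-ago must be greater than --before-days.")
    now = datetime.datetime.now()
    return now - datetime.timedelta(days=from_days_ago), now - datetime.timedelta(days=before_days)

# Example: messages created between 90 and 30 days ago.
start_from, end_before = relative_window(90, 30)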
View File

@@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from core.workflow.variables.input_entities import VariableEntity
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@@ -32,7 +32,7 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,

View File

@@ -17,8 +17,8 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine

View File

@@ -21,7 +21,7 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.file import file_manager

View File

@@ -5,7 +5,7 @@ from core.app.app_config.entities import (
PromptTemplateEntity,
)
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.prompt.simple_prompt_transform import ModelMode
from models.model import AppMode

View File

@@ -1,7 +1,8 @@
import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[

View File

@@ -2,12 +2,12 @@ from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
from core.workflow.file import FileUploadConfig
from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity
from models.model import AppMode
@@ -90,61 +90,7 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""
type: VariableEntityType
required: bool = False
hide: bool = False
default: Any = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
def convert_none_description(cls, v: Any) -> str:
return v or ""
@field_validator("options", mode="before")
@classmethod
def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
class RagPipelineVariableEntity(VariableEntity):
class RagPipelineVariableEntity(WorkflowVariableEntity):
"""
Rag Pipeline Variable Entity.
"""
@@ -314,7 +260,7 @@ class AppConfig(BaseModel):
app_id: str
app_mode: AppMode
additional_features: AppAdditionalFeatures | None = None
variables: list[VariableEntity] = []
variables: list[WorkflowVariableEntity] = []
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None

View File

@@ -1,6 +1,7 @@
import re
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from core.app.app_config.entities import RagPipelineVariableEntity
from core.workflow.variables.input_entities import VariableEntity
from models.workflow import Workflow

View File

@@ -32,8 +32,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.model_runtime.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import (

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import NodeType
from core.workflow.file import File, FileUploadConfig
@@ -12,13 +11,14 @@ from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from core.workflow.variables.input_entities import VariableEntityType
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
from core.app.app_config.entities import VariableEntity
from core.workflow.variables.input_entities import VariableEntity
class BaseAppGenerator:

View File

@@ -33,14 +33,10 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.model_runtime.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.file.enums import FileTransferMethod, FileType
from extensions.ext_database import db

View File

@@ -27,7 +27,7 @@ from core.app.entities.task_entities import (
CompletionAppStreamResponse,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic

View File

@@ -1 +1,5 @@
"""LLM-related application services."""
from .quota import deduct_llm_quota, ensure_llm_quota_available
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]

93 api/core/app/llm/quota.py Normal file
View File

@@ -0,0 +1,93 @@
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
provider_model = provider_configuration.get_provider_model(
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

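The unit handling in deduct_llm_quota maps each quota unit to a deduction amount: token-based quotas consume usage.total_tokens, credit-based quotas consume the per-model credit price, and anything else counts one invocation. A self-contained sketch of just that mapping, with stand-in types in place of the real Dify entities:

from enum import Enum

class QuotaUnit(Enum):  # stand-in for core.entities.provider_entities.QuotaUnit
    TOKENS = "tokens"
    CREDITS = "credits"
    TIMES = "times"

def used_quota_for(unit: QuotaUnit, total_tokens: int, model_credits: int) -> int:
    if unit == QuotaUnit.TOKENS:
        return total_tokens   # deduct by token volume
    if unit == QuotaUnit.CREDITS:
        return model_credits  # deduct the model's credit price
    return 1                  # otherwise, one call costs one unit

assert used_quota_for(QuotaUnit.TOKENS, total_tokens=1200, model_credits=3) == 1200
assert used_quota_for(QuotaUnit.TIMES, total_tokens=1200, model_credits=3) == 1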
View File

@@ -52,10 +52,10 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from core.workflow.file import helpers as file_helpers
from core.workflow.file.enums import FileTransferMethod

View File

@@ -1,9 +1,11 @@
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
from .llm_quota import LLMQuotaLayer
from .observability import ObservabilityLayer
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
__all__ = [
"LLMQuotaLayer",
"ObservabilityLayer",
"PersistenceWorkflowInfo",
"WorkflowPersistenceLayer",

View File

@@ -0,0 +1,128 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.workflow.enums import NodeType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.base.node import Node
if TYPE_CHECKING:
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
@final
class LLMQuotaLayer(GraphEngineLayer):
"""Graph layer that applies LLM quota deduction after node execution."""
def __init__(self) -> None:
super().__init__()
self._abort_sent = False
@override
def on_graph_start(self) -> None:
self._abort_sent = False
@override
def on_event(self, event: GraphEngineEvent) -> None:
_ = event
@override
def on_graph_end(self, error: Exception | None) -> None:
_ = error
@override
def on_node_run_start(self, node: Node) -> None:
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
ensure_llm_quota_available(model_instance=model_instance)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
@override
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
deduct_llm_quota(
tenant_id=node.tenant_id,
model_instance=model_instance,
usage=result_event.node_run_result.llm_usage,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
except Exception:
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
@staticmethod
def _set_stop_event(node: Node) -> None:
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
if stop_event is not None:
stop_event.set()
def _send_abort_command(self, *, reason: str) -> None:
if not self.command_channel or self._abort_sent:
return
try:
self.command_channel.send_command(
AbortCommand(
command_type=CommandType.ABORT,
reason=reason,
)
)
self._abort_sent = True
except Exception:
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case NodeType.LLM:
return cast("LLMNode", node).model_instance
case NodeType.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
case NodeType.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
node.id,
)
return None

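LLMQuotaLayer plugs into the engine through the GraphEngineLayer hooks seen above (on_graph_start, on_node_run_start, on_node_run_end, on_graph_end) and is attached with a single graph_engine.layer(LLMQuotaLayer()) call, as the iteration-node change later in this compare shows. A toy layer with the same hook shape, using a simplified stand-in base class rather than the real one:

class GraphEngineLayer:  # simplified stand-in for the real base class
    def on_graph_start(self) -> None: ...
    def on_node_run_start(self, node) -> None: ...
    def on_node_run_end(self, node, error, result_event=None) -> None: ...
    def on_graph_end(self, error) -> None: ...

class CountSucceededLayer(GraphEngineLayer):
    """Toy layer: counts successful node runs via the same hook quota deduction uses."""

    def __init__(self) -> None:
        self.succeeded = 0

    def on_node_run_end(self, node, error, result_event=None) -> None:
        if error is None:
            self.succeeded += 1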
View File

@@ -1,6 +1,8 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
@@ -11,14 +13,16 @@ from core.helper.code_executor.code_executor import (
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.enums import NodeType, SystemVariableKey
from core.workflow.file.file_manager import file_manager
from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
@@ -29,11 +33,9 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import PromptMessageMemory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@@ -41,12 +43,34 @@ from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def fetch_memory(
*,
conversation_id: str | None,
app_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
) -> TokenBufferMemory | None:
if not node_data_memory or not conversation_id:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
class DefaultWorkflowCodeExecutor:
def execute(
self,
@@ -221,6 +245,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
@@ -229,10 +254,12 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
@@ -241,6 +268,7 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
return node_class(
@@ -295,8 +323,14 @@ class DifyNodeFactory(NodeFactory):
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
return llm_utils.fetch_memory(
variable_pool=self.graph_runtime_state.variable_pool,
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
conversation_id = (
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
)
return fetch_memory(
conversation_id=conversation_id,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,

View File

@@ -27,10 +27,10 @@ from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from extensions.ext_storage import storage

View File

@@ -4,10 +4,10 @@ from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types as mcp_types
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService

View File

@@ -14,7 +14,7 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.workflow.file import file_manager
from extensions.ext_database import db
from factories import file_factory

View File

@@ -0,0 +1,3 @@
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol
from core.model_runtime.entities import PromptMessage
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages."""
def get_history_prompt_messages(
self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...

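Since PromptMessageMemory is a typing.Protocol, conformance is structural: any object exposing a matching get_history_prompt_messages satisfies it without inheriting from the port. A toy conforming implementation, with PromptMessage stubbed out so the sketch stands alone:

from collections.abc import Sequence
from dataclasses import dataclass

@dataclass
class PromptMessage:  # stand-in for core.model_runtime.entities.PromptMessage
    role: str
    content: str

class FixedHistoryMemory:
    """Structurally satisfies PromptMessageMemory without subclassing it."""

    def __init__(self, messages: Sequence[PromptMessage]) -> None:
        self._messages = list(messages)

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        # Token-based trimming is elided in this sketch.
        return self._messages if message_limit is None else self._messages[-message_limit:]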
View File

@@ -2,6 +2,7 @@ import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.app.llm import deduct_llm_quota
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import (
@@ -29,7 +30,6 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant
@@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
@@ -126,16 +124,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(
response: LLMResultWithStructuredOutput,

View File

@@ -14,13 +14,9 @@ from core.model_runtime.entities import (
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.model_runtime.prompt.prompt_transform import PromptTransform
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import file_manager
from core.workflow.file.models import File
from core.workflow.runtime import VariablePool

View File

@@ -10,7 +10,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.prompt_transform import PromptTransform
from core.prompt.prompt_transform import PromptTransform
class AgentHistoryPromptTransform(PromptTransform):

View File

@@ -5,7 +5,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:

View File

@@ -15,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.model_runtime.prompt.prompt_transform import PromptTransform
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import file_manager
from models.model import AppMode

View File

@@ -1,6 +1,6 @@
from sqlalchemy import select
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message

View File

@@ -10,7 +10,7 @@ from core.model_runtime.entities import (
PromptMessageRole,
TextPromptMessageContent,
)
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.prompt.simple_prompt_transform import ModelMode
class PromptMessageUtil:

View File

@@ -8,6 +8,7 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
@@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
@@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Deduct quota for summary generation (same as workflow nodes)
try:
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))

View File

@@ -29,12 +29,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService

View File

@@ -2,14 +2,14 @@ from collections.abc import Generator, Sequence
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import llm_utils
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
return text, usage

View File

@@ -6,9 +6,9 @@ identity:
zh_Hans: 网页抓取
pt_BR: WebScraper
description:
en_US: Web Scrapper tool kit is used to scrape web
en_US: Web Scraper tool kit is used to scrape web
zh_Hans: 一个用于抓取网页的工具。
pt_BR: Web Scrapper tool kit is used to scrape web
pt_BR: Web Scraper tool kit is used to scrape web
icon: icon.svg
tags:
- productivity

View File

@@ -1,11 +1,11 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import OutputVariableEntity
from core.workflow.variables.input_entities import VariableEntity
class WorkflowToolConfigurationUtils:

View File

@@ -5,7 +5,6 @@ from collections.abc import Mapping
from pydantic import Field
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption
@@ -23,6 +22,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode

View File

@@ -3,7 +3,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData

View File

@@ -588,6 +588,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -642,5 +643,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@@ -4,11 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector

View File

@@ -1,26 +1,19 @@
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities import PromptMessageRole
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from .exc import InvalidVariableTypeError
@@ -48,88 +41,51 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def fetch_memory(
    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
    if not node_data_memory:
        return None
    # get conversation id
    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
    if not isinstance(conversation_id_variable, StringSegment):
        return None
    conversation_id = conversation_id_variable.value
    with Session(db.engine, expire_on_commit=False) as session:
        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
        conversation = session.scalar(stmt)
    if not conversation:
        return None
    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
    return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
    provider_model_bundle = model_instance.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration
    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return
    system_configuration = provider_configuration.system_configuration
    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit
            if quota_configuration.quota_limit == -1:
                return
            break
    used_quota = None
    if quota_unit:
        if quota_unit == QuotaUnit.TOKENS:
            used_quota = usage.total_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            used_quota = dify_config.get_model_credits(model_instance.model_name)
        else:
            used_quota = 1
    if used_quota is not None and system_configuration.current_quota_type is not None:
        if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
            from services.credit_pool_service import CreditPoolService
            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
            )
        elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
            from services.credit_pool_service import CreditPoolService
            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
                pool_type="paid",
            )
        else:
            with Session(db.engine) as session:
                stmt = (
                    update(Provider)
                    .where(
                        Provider.tenant_id == tenant_id,
                        # TODO: Use provider name with prefix after the data migration.
                        Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
                        Provider.provider_type == ProviderType.SYSTEM.value,
                        Provider.quota_type == system_configuration.current_quota_type.value,
                        Provider.quota_limit > Provider.quota_used,
                    )
                    .values(
                        quota_used=Provider.quota_used + used_quota,
                        last_used=naive_utc_now(),
                    )
                )
                session.execute(stmt)
                session.commit()
def convert_history_messages_to_text(
    *,
    history_messages: Sequence[PromptMessage],
    human_prefix: str,
    ai_prefix: str,
) -> str:
    string_messages: list[str] = []
    for message in history_messages:
        if message.role == PromptMessageRole.USER:
            role = human_prefix
        elif message.role == PromptMessageRole.ASSISTANT:
            role = ai_prefix
        else:
            continue
        if isinstance(message.content, list):
            content_parts = []
            for content in message.content:
                if isinstance(content, TextPromptMessageContent):
                    content_parts.append(content.data)
                elif isinstance(content, ImagePromptMessageContent):
                    content_parts.append("[image]")
            inner_msg = "\n".join(content_parts)
            string_messages.append(f"{role}: {inner_msg}")
        else:
            string_messages.append(f"{role}: {message.content}")
    return "\n".join(string_messages)
def fetch_memory_text(
*,
memory: PromptMessageMemory,
max_token_limit: int,
message_limit: int | None = None,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
) -> str:
history_messages = memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit,
)
return convert_history_messages_to_text(
history_messages=history_messages,
human_prefix=human_prefix,
ai_prefix=ai_prefix,
)

View File

@ -37,9 +37,10 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
@ -62,7 +63,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from core.workflow.variables import (
ArrayFileSegment,
@ -278,8 +279,6 @@ class LLMNode(Node[LLMNodeData]):
else None
)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
elif isinstance(event, LLMStructuredOutput):
structured_output = event
@ -1234,6 +1233,10 @@ class LLMNode(Node[LLMNodeData]):
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
@ -1336,48 +1339,16 @@ def _handle_memory_completion_mode(
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
memory_text = _convert_history_messages_to_text(
history_messages=memory_messages,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
memory_text = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_text
def _convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
return "\n".join(string_messages)
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,

View File

@ -1,10 +1,8 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Protocol
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessage
class CredentialsProvider(Protocol):
@ -21,13 +19,3 @@ class ModelFactory(Protocol):
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages for LLM nodes."""
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...
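
The protocol above moves to core.model_runtime.memory (see the imports updated elsewhere in this diff). A minimal sketch of a conforming implementation — the class name and trimming logic here are illustrative, not part of the diff:

from collections.abc import Sequence

from core.model_runtime.entities import PromptMessage


class StaticPromptMessageMemory:
    """Illustrative in-memory PromptMessageMemory implementation."""

    def __init__(self, messages: Sequence[PromptMessage]) -> None:
        self._messages = list(messages)

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        # Honors only the message-count limit; a real adapter would also
        # trim history to fit max_token_limit.
        if message_limit is not None:
            return self._messages[-message_limit:]
        return self._messages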

View File

@ -413,6 +413,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@ -454,5 +455,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@ -7,7 +7,7 @@ from pydantic import (
field_validator,
)
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
from core.workflow.variables.types import SegmentType

View File

@ -5,7 +5,6 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@ -18,13 +17,18 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
@ -97,6 +101,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_memory: PromptMessageMemory | None
def __init__(
self,
@ -108,6 +113,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
) -> None:
super().__init__(
id=id,
@ -118,6 +124,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -163,13 +170,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
memory = self._memory
if (
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
@ -308,9 +309,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage, tool_call
def _generate_function_call_prompt(
@ -319,7 +317,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@ -407,7 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -445,7 +443,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -470,7 +468,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
files=files,
context="",
memory_config=node_data.memory,
memory=memory,
# AdvancedPromptTransform is still typed against TokenBufferMemory.
memory=cast(Any, memory),
model_instance=model_instance,
image_detail_config=vision_detail,
)
@ -483,7 +482,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -715,7 +714,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@ -724,8 +723,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@ -742,7 +741,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -751,8 +750,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@ -828,6 +827,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return rest_tokens
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig

View File

@ -3,12 +3,12 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.prompt.simple_prompt_transform import ModelMode
from core.model_runtime.prompt.utils.prompt_message_util import PromptMessageUtil
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
NodeExecutionType,
@ -56,6 +56,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
def __init__(
self,
@ -67,6 +68,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@ -81,6 +83,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@ -103,13 +106,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
memory = self._memory
# fetch instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
@ -240,6 +237,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
llm_usage=usage,
)
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -323,7 +324,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -336,7 +337,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
input_text = query
memory_str = ""
if memory:
memory_str = memory.get_history_prompt_text(
memory_str = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)

View File

@ -2,8 +2,8 @@ from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.input_entities import VariableEntity
class StartNodeData(BaseNodeData):

View File

@ -2,12 +2,12 @@ from typing import Any
from jsonschema import Draft7Validator, ValidationError
from core.app.app_config.entities import VariableEntityType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.variables.input_entities import VariableEntityType
class StartNode(Node[StartNodeData]):

View File

@ -1,3 +1,4 @@
from .input_entities import VariableEntity, VariableEntityType
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
@ -64,4 +65,6 @@ __all__ = [
"StringVariable",
"Variable",
"VariableBase",
"VariableEntity",
"VariableEntityType",
]

View File

@ -0,0 +1,62 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Any
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from core.workflow.file import FileTransferMethod, FileType
class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
"""
Shared variable entity used by workflow runtime and app configuration.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""
type: VariableEntityType
required: bool = False
hide: bool = False
default: Any = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
def convert_none_description(cls, value: Any) -> str:
return value or ""
@field_validator("options", mode="before")
@classmethod
def convert_none_options(cls, value: Any) -> Sequence[str]:
return value or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as error:
raise ValueError(f"Invalid JSON schema: {error.message}")
return schema
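
A usage sketch for the shared entity — the field values below are invented for illustration:

color_input = VariableEntity(
    variable="color",
    label="Favorite color",
    type=VariableEntityType.SELECT,
    required=True,
    options=["red", "green", "blue"],
)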

View File

@ -6,6 +6,7 @@ from typing import Any, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
@ -106,6 +107,7 @@ class WorkflowEntry:
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
self.graph_engine.layer(LLMQuotaLayer())
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():

View File

@ -0,0 +1,48 @@
"""Helpers for producing concise pyrefly diagnostics for CI diff output."""
from __future__ import annotations
import sys
_DIAGNOSTIC_PREFIXES = ("ERROR ", "WARNING ")
_LOCATION_PREFIX = "-->"
def extract_diagnostics(raw_output: str) -> str:
"""Extract stable diagnostic lines from pyrefly output.
The full pyrefly output includes code excerpts and carets, which create noisy
diffs. This helper keeps only:
- diagnostic headline lines (``ERROR ...`` / ``WARNING ...``)
- the following location line (``--> path:line:column``), when present
"""
lines = raw_output.splitlines()
diagnostics: list[str] = []
for index, line in enumerate(lines):
if line.startswith(_DIAGNOSTIC_PREFIXES):
diagnostics.append(line.rstrip())
next_index = index + 1
if next_index < len(lines):
next_line = lines[next_index]
if next_line.lstrip().startswith(_LOCATION_PREFIX):
diagnostics.append(next_line.rstrip())
if not diagnostics:
return ""
return "\n".join(diagnostics) + "\n"
def main() -> int:
"""Read pyrefly output from stdin and print normalized diagnostics."""
raw_output = sys.stdin.read()
sys.stdout.write(extract_diagnostics(raw_output))
return 0
if __name__ == "__main__":
sys.exit(main())
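
A quick sketch of the expected behavior, with a made-up pyrefly-style sample:

sample = (
    "INFO Checking project\n"
    "ERROR `x` may be uninitialized [unbound-name]\n"
    "  --> api/app.py:10:5\n"
    "   |\n"
    "10 |     return x\n"
)
print(extract_diagnostics(sample), end="")
# ERROR `x` may be uninitialized [unbound-name]
#   --> api/app.py:10:5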

View File

@ -787,7 +787,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
__tablename__ = "workflow_node_executions"
@declared_attr
@declared_attr.directive
@classmethod
def __table_args__(cls) -> Any:
return (

View File

@ -68,7 +68,7 @@ dependencies = [
"pydantic~=2.12.5",
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.12.0",
"pyjwt~=2.10.1",
"pyjwt~=2.11.0",
"pypdfium2==5.2.0",
"python-docx~=1.2.0",
"python-dotenv==1.0.1",
@ -124,7 +124,7 @@ dev = [
"pytest-env~=1.1.3",
"pytest-mock~=3.14.0",
"testcontainers~=4.13.2",
"types-aiofiles~=24.1.0",
"types-aiofiles~=25.1.0",
"types-beautifulsoup4~=4.12.0",
"types-cachetools~=5.5.0",
"types-colorama~=0.4.15",

View File

@ -29,7 +29,7 @@ from typing import Any, cast
import sqlalchemy as sa
from pydantic import ValidationError
from sqlalchemy import and_, delete, func, null, or_, select
from sqlalchemy import and_, delete, func, null, or_, select, tuple_
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
@ -423,9 +423,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
if last_seen:
stmt = stmt.where(
or_(
WorkflowRun.created_at > last_seen[0],
and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
tuple_(WorkflowRun.created_at, WorkflowRun.id)
> tuple_(
sa.literal(last_seen[0], type_=sa.DateTime()),
sa.literal(last_seen[1], type_=WorkflowRun.id.type),
)
)
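
The row-value comparison is equivalent to the or_/and_ predicate it replaces; conceptually:

# (created_at, id) > (last_created_at, last_id)
#   <=> created_at > last_created_at
#       OR (created_at == last_created_at AND id > last_id)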

View File

@ -1,6 +1,6 @@
import copy
from core.model_runtime.prompt.prompt_templates.advanced_prompt_templates import (
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,

View File

@ -8,18 +8,18 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
)
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.prompt.simple_prompt_transform import SimplePromptTransform
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import SimplePromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file.models import FileUploadConfig
from core.workflow.nodes import NodeType
from core.workflow.variables.input_entities import VariableEntity
from events.app_event import app_was_created
from extensions.ext_database import db
from models import Account

View File

@ -9,7 +9,6 @@ from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -40,6 +39,7 @@ from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import load_into_variable_pool
from core.workflow.variables import VariableBase
from core.workflow.variables.input_entities import VariableEntityType
from core.workflow.variables.variables import Variable
from core.workflow.workflow_entry import WorkflowEntry
from enums.cloud_plan import CloudPlan

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_manager import ModelInstance
from core.model_runtime.entities import AssistantPromptMessage
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
@ -22,19 +22,17 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod
def get_mocked_fetch_memory(memory_text: str):
class MemoryMock:
def get_history_prompt_text(
def get_history_prompt_messages(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
):
return memory_text
return [UserPromptMessage(content=memory_text), AssistantPromptMessage(content="mocked answer")]
return MagicMock(return_value=MemoryMock())
def init_parameter_extractor_node(config: dict):
def init_parameter_extractor_node(config: dict, memory=None):
graph_config = {
"edges": [
{
@ -79,6 +77,7 @@ def init_parameter_extractor_node(config: dict):
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
memory=memory,
)
return node
@ -350,7 +349,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
def test_chat_parameter_extractor_with_memory(setup_model_mock):
"""
Test chat parameter extractor with memory.
"""
@ -373,6 +372,7 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
"memory": {"window": {"enabled": True, "size": 50}},
},
},
memory=get_mocked_fetch_memory("customized memory")(),
)
node._model_instance = get_mocked_fetch_model_instance(
@ -381,8 +381,6 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)()
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()
result = node._run()

View File

@ -3,7 +3,7 @@ import copy
import pytest
from faker import Faker
from core.model_runtime.prompt.prompt_templates.advanced_prompt_templates import (
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,

View File

@ -10,11 +10,10 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models import Account, Tenant
from models.api_based_extension import APIBasedExtension
from models.model import App, AppMode, AppModelConfig

View File

@ -147,8 +147,7 @@ class TestDisableSegmentsFromIndexTask:
document.cleaning_completed_at = fake.date_time_this_year()
document.splitting_completed_at = fake.date_time_this_year()
document.tokens = fake.random_int(min=50, max=500)
document.indexing_started_at = fake.date_time_this_year()
document.indexing_completed_at = fake.date_time_this_year()
document.completed_at = fake.date_time_this_year()
document.indexing_status = "completed"
document.enabled = True
document.archived = False

View File

@ -0,0 +1,184 @@
import datetime
import re
from unittest.mock import MagicMock, patch
import click
import pytest
from commands import clean_expired_messages
def _mock_service() -> MagicMock:
service = MagicMock()
service.run.return_value = {
"batches": 1,
"total_messages": 10,
"filtered_messages": 5,
"total_deleted": 5,
}
return service
def test_absolute_mode_calls_from_time_range():
policy = object()
service = _mock_service()
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 2, 1, 0, 0, 0)
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
patch("commands.MessagesCleanService.from_days") as mock_from_days,
):
clean_expired_messages.callback(
batch_size=200,
graceful_period=21,
start_from=start_from,
end_before=end_before,
from_days_ago=None,
before_days=None,
dry_run=True,
)
mock_from_time_range.assert_called_once_with(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=200,
dry_run=True,
)
mock_from_days.assert_not_called()
def test_relative_mode_before_days_only_calls_from_days():
policy = object()
service = _mock_service()
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_days", return_value=service) as mock_from_days,
patch("commands.MessagesCleanService.from_time_range") as mock_from_time_range,
):
clean_expired_messages.callback(
batch_size=500,
graceful_period=14,
start_from=None,
end_before=None,
from_days_ago=None,
before_days=30,
dry_run=False,
)
mock_from_days.assert_called_once_with(
policy=policy,
days=30,
batch_size=500,
dry_run=False,
)
mock_from_time_range.assert_not_called()
def test_relative_mode_with_from_days_ago_calls_from_time_range():
policy = object()
service = _mock_service()
fixed_now = datetime.datetime(2024, 8, 20, 12, 0, 0)
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
patch("commands.MessagesCleanService.from_days") as mock_from_days,
patch("commands.datetime", autospec=True) as mock_datetime,
):
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
clean_expired_messages.callback(
batch_size=1000,
graceful_period=21,
start_from=None,
end_before=None,
from_days_ago=60,
before_days=30,
dry_run=False,
)
mock_from_time_range.assert_called_once_with(
policy=policy,
start_from=fixed_now - datetime.timedelta(days=60),
end_before=fixed_now - datetime.timedelta(days=30),
batch_size=1000,
dry_run=False,
)
mock_from_days.assert_not_called()
@pytest.mark.parametrize(
("kwargs", "message"),
[
(
{
"start_from": datetime.datetime(2024, 1, 1),
"end_before": datetime.datetime(2024, 2, 1),
"from_days_ago": None,
"before_days": 30,
},
"mutually exclusive",
),
(
{
"start_from": datetime.datetime(2024, 1, 1),
"end_before": None,
"from_days_ago": None,
"before_days": None,
},
"Both --start-from and --end-before are required",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": 10,
"before_days": None,
},
"--from-days-ago must be used together with --before-days",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": None,
"before_days": -1,
},
"--before-days must be >= 0",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": 30,
"before_days": 30,
},
"--from-days-ago must be greater than --before-days",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": None,
"before_days": None,
},
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])",
),
],
)
def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str):
with pytest.raises(click.UsageError, match=re.escape(message)):
clean_expired_messages.callback(
batch_size=1000,
graceful_period=21,
start_from=kwargs["start_from"],
end_before=kwargs["end_before"],
from_days_ago=kwargs["from_days_ago"],
before_days=kwargs["before_days"],
dry_run=False,
)
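
Read together, these cases pin down the command's two invocation modes — flag spellings inferred from the error messages above, so treat this as a sketch:

# Absolute window (both flags required; exclusive with the relative flags):
#   --start-from 2024-01-01 --end-before 2024-02-01
# Relative window (--from-days-ago is optional and must exceed --before-days):
#   --before-days 30
#   --from-days-ago 60 --before-days 30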

View File

@ -1,7 +1,7 @@
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.base_app_generator import BaseAppGenerator
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
def test_validate_inputs_with_zero():

View File

@ -4,7 +4,6 @@ from unittest.mock import Mock, patch
import jsonschema
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types
from core.mcp.server.streamable_http import (
@ -19,6 +18,7 @@ from core.mcp.server.streamable_http import (
prepare_tool_arguments,
process_mapping_response,
)
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@ -11,13 +11,9 @@ from core.model_runtime.entities.message_entities import (
PromptMessageRole,
UserPromptMessage,
)
from core.model_runtime.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.model_runtime.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
MemoryConfig,
)
from core.model_runtime.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file import File, FileTransferMethod, FileType
from models.model import Conversation

View File

@ -12,7 +12,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from models.model import Conversation

View File

@ -1,7 +1,7 @@
from uuid import uuid4
from constants import UUID_NIL
from core.model_runtime.prompt.utils.extract_thread_messages import extract_thread_messages
from core.prompt.utils.extract_thread_messages import extract_thread_messages
class MockMessage:

View File

@ -6,7 +6,7 @@
# from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
# from core.model_runtime.entities.provider_entities import ProviderEntity
# from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
# from core.model_runtime.prompt.prompt_transform import PromptTransform
# from core.prompt.prompt_transform import PromptTransform
# def test__calculate_rest_token():

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.prompt.simple_prompt_transform import SimplePromptTransform
from core.prompt.simple_prompt_transform import SimplePromptTransform
from models.model import AppMode, Conversation

View File

@ -0,0 +1,174 @@
import threading
from datetime import datetime
from unittest.mock import MagicMock, patch
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.commands import CommandType
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
def _build_succeeded_event() -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
id="execution-id",
node_id="llm-node-id",
node_type=NodeType.LLM,
start_at=datetime.now(),
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"question": "hello"},
llm_usage=LLMUsage.empty_usage(),
),
)
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.LLM
node.tenant_id = "tenant-id"
node.model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
usage=result_event.node_run_result.llm_usage,
)
def test_deduct_quota_called_for_question_classifier_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "question-classifier-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.QUESTION_CLASSIFIER
node.tenant_id = "tenant-id"
node.model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
usage=result_event.node_run_result.llm_usage,
)
def test_non_llm_node_is_ignored() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "start-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.START
node.tenant_id = "tenant-id"
node._model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_not_called()
def test_quota_error_is_handled_in_layer() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.LLM
node.tenant_id = "tenant-id"
node.model_instance = object()
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=ValueError("quota exceeded"),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.LLM
node.tenant_id = "tenant-id"
node.model_instance = object()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
assert stop_event.is_set()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "No credits remaining"
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = NodeType.LLM
node.model_instance = object()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
):
layer.on_node_run_start(node)
assert stop_event.is_set()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "Model provider openai quota exceeded."
def test_quota_precheck_passes_without_abort() -> None:
layer = LLMQuotaLayer()
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = NodeType.LLM
node.model_instance = object()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert not stop_event.is_set()
mock_check.assert_called_once_with(model_instance=node.model_instance)
layer.command_channel.send_command.assert_not_called()

View File

@ -21,7 +21,7 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.model_runtime.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities import GraphInitParams
from core.workflow.file import File, FileTransferMethod, FileType
from core.workflow.nodes.llm import llm_utils

View File

@ -4,12 +4,12 @@ import time
import pytest
from pydantic import ValidationError as PydanticValidationError
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.workflow.entities import GraphInitParams
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
def make_start_node(user_inputs, variables):

View File

@ -0,0 +1,51 @@
from libs.pyrefly_diagnostics import extract_diagnostics
def test_extract_diagnostics_keeps_only_summary_and_location_lines() -> None:
# Arrange
raw_output = """INFO Checking project configured at `/tmp/project/pyrefly.toml`
ERROR `result` may be uninitialized [unbound-name]
--> controllers/console/app/annotation.py:126:16
|
126 | return result, 200
| ^^^^^^
|
ERROR Object of class `App` has no attribute `access_mode` [missing-attribute]
--> controllers/console/app/app.py:574:13
|
574 | app_model.access_mode = app_setting.access_mode
| ^^^^^^^^^^^^^^^^^^^^^
"""
# Act
diagnostics = extract_diagnostics(raw_output)
# Assert
assert diagnostics == (
"ERROR `result` may be uninitialized [unbound-name]\n"
" --> controllers/console/app/annotation.py:126:16\n"
"ERROR Object of class `App` has no attribute `access_mode` [missing-attribute]\n"
" --> controllers/console/app/app.py:574:13\n"
)
def test_extract_diagnostics_handles_error_without_location_line() -> None:
# Arrange
raw_output = "ERROR unexpected pyrefly output format [bad-format]\n"
# Act
diagnostics = extract_diagnostics(raw_output)
# Assert
assert diagnostics == "ERROR unexpected pyrefly output format [bad-format]\n"
def test_extract_diagnostics_returns_empty_for_non_error_output() -> None:
# Arrange
raw_output = "INFO Checking project configured at `/tmp/project/pyrefly.toml`\n"
# Act
diagnostics = extract_diagnostics(raw_output)
# Assert
assert diagnostics == ""

View File

@ -13,12 +13,11 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import AppMode
from services.workflow.workflow_converter import WorkflowConverter

api/uv.lock (generated)
View File

@ -1636,7 +1636,7 @@ requires-dist = [
{ name = "pydantic", specifier = "~=2.12.5" },
{ name = "pydantic-extra-types", specifier = "~=2.10.3" },
{ name = "pydantic-settings", specifier = "~=2.12.0" },
{ name = "pyjwt", specifier = "~=2.10.1" },
{ name = "pyjwt", specifier = "~=2.11.0" },
{ name = "pypdfium2", specifier = "==5.2.0" },
{ name = "python-docx", specifier = "~=1.2.0" },
{ name = "python-dotenv", specifier = "==1.0.1" },
@ -1683,7 +1683,7 @@ dev = [
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
{ name = "sseclient-py", specifier = ">=1.8.0" },
{ name = "testcontainers", specifier = "~=4.13.2" },
{ name = "types-aiofiles", specifier = "~=24.1.0" },
{ name = "types-aiofiles", specifier = "~=25.1.0" },
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
{ name = "types-cachetools", specifier = "~=5.5.0" },
{ name = "types-cffi", specifier = ">=1.17.0" },
@ -4957,11 +4957,11 @@ wheels = [
[[package]]
name = "pyjwt"
version = "2.10.1"
version = "2.11.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
{ url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" },
]
[package.optional-dependencies]
@ -6293,11 +6293,11 @@ wheels = [
[[package]]
name = "types-aiofiles"
version = "24.1.0.20250822"
version = "25.1.0.20251011"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/19/48/c64471adac9206cc844afb33ed311ac5a65d2f59df3d861e0f2d0cad7414/types_aiofiles-24.1.0.20250822.tar.gz", hash = "sha256:9ab90d8e0c307fe97a7cf09338301e3f01a163e39f3b529ace82466355c84a7b", size = 14484, upload-time = "2025-08-22T03:02:23.039Z" }
sdist = { url = "https://files.pythonhosted.org/packages/84/6c/6d23908a8217e36704aa9c79d99a620f2fdd388b66a4b7f72fbc6b6ff6c6/types_aiofiles-25.1.0.20251011.tar.gz", hash = "sha256:1c2b8ab260cb3cd40c15f9d10efdc05a6e1e6b02899304d80dfa0410e028d3ff", size = 14535, upload-time = "2025-10-11T02:44:51.237Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/8e/5e6d2215e1d8f7c2a94c6e9d0059ae8109ce0f5681956d11bb0a228cef04/types_aiofiles-24.1.0.20250822-py3-none-any.whl", hash = "sha256:0ec8f8909e1a85a5a79aed0573af7901f53120dd2a29771dd0b3ef48e12328b0", size = 14322, upload-time = "2025-08-22T03:02:21.918Z" },
{ url = "https://files.pythonhosted.org/packages/71/0f/76917bab27e270bb6c32addd5968d69e558e5b6f7fb4ac4cbfa282996a96/types_aiofiles-25.1.0.20251011-py3-none-any.whl", hash = "sha256:8ff8de7f9d42739d8f0dadcceeb781ce27cd8d8c4152d4a7c52f6b20edb8149c", size = 14338, upload-time = "2025-10-11T02:44:50.054Z" },
]
[[package]]

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Alert from './alert'
import Alert from '../alert'
describe('Alert', () => {
const defaultProps = {

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import AppUnavailable from './app-unavailable'
import AppUnavailable from '../app-unavailable'
describe('AppUnavailable', () => {
beforeEach(() => {

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import Badge from './badge'
import Badge from '../badge'
describe('Badge', () => {
describe('Rendering', () => {

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ThemeSelector from './theme-selector'
import ThemeSelector from '../theme-selector'
// Mock next-themes with controllable state
let mockTheme = 'system'

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ThemeSwitcher from './theme-switcher'
import ThemeSwitcher from '../theme-switcher'
let mockTheme = 'system'
const mockSetTheme = vi.fn()

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import { ActionButton, ActionButtonState } from './index'
import { ActionButton, ActionButtonState } from '../index'
describe('ActionButton', () => {
it('renders button with default props', () => {

Some files were not shown because too many files have changed in this diff.