Compare commits


22 Commits

SHA1 Message Date
662f6e3e68 Merge branch 'main' into feat/new-biliing-quota 2026-04-17 18:28:18 +08:00
ae01a5d137 fix: unit test mock 2026-04-08 14:42:52 +08:00
ad6670ebcc fix: correct quota info response 2026-04-08 14:23:57 +08:00
8ca0917044 Merge branch 'main' into feat/new-biliing-quota 2026-04-08 13:39:24 +08:00
b3870524d4 fix usage get 2026-04-02 09:52:52 +08:00
c543188434 fix linter 2026-03-31 15:22:51 +08:00
f319a9e42f fix test case 2026-03-31 15:22:43 +08:00
58241a89a5 fix linter 2026-03-31 14:59:54 +08:00
422bf3506e rebuild quota service 2026-03-31 14:59:45 +08:00
6e745f9e9b fix linter 2026-03-31 09:49:24 +08:00
4e50d55339 fix comment 2026-03-31 09:49:09 +08:00
b95cdabe26 [autofix.ci] apply automated fixes 2026-03-30 08:45:37 +00:00
daa47c25bb Merge branch 'feat/new-biliing-quota' of github.com:langgenius/dify into feat/new-biliing-quota 2026-03-30 16:43:13 +08:00
f1bcd6d715 add test case for quota and billing service 2026-03-30 16:41:56 +08:00
8643ff43f5 Merge branch 'main' into feat/new-biliing-quota 2026-03-30 15:57:49 +08:00
c5f30a47f0 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-30 15:26:38 +08:00
37d438fa19 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-27 16:26:09 +08:00
9503803997 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-23 09:27:39 +08:00
d6476f5434 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-20 15:17:27 +08:00
80b4633e8f fix style check and test 2026-03-20 14:58:31 +08:00
3888969af3 [autofix.ci] apply automated fixes 2026-03-20 05:45:30 +00:00
658ac15589 use new quota system 2026-03-20 13:29:22 +08:00
1440 changed files with 13085 additions and 16699 deletions

View File

@@ -76,11 +76,13 @@ jobs:
diff += '\\n\\n... (truncated) ...';
}
if (diff.trim()) {
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body: '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>',
});
}
const body = diff.trim()
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
: '### Pyrefly Diff\nNo changes detected.';
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});

View File

@@ -89,37 +89,3 @@ jobs:
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
dify-ui-test:
name: dify-ui Tests
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
working-directory: ./packages/dify-ui
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Install Chromium for Browser Mode
run: vp exec playwright install --with-deps chromium
- name: Run dify-ui tests
run: vp test run --coverage --silent=passed-only
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: packages/dify-ui/coverage
flags: dify-ui
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}

View File

@@ -2,7 +2,6 @@ import base64
import secrets
import click
from sqlalchemy.orm import Session
from constants.languages import languages
from extensions.ext_database import db
@@ -44,11 +43,10 @@ def reset_password(email, new_password, password_confirm):
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
with Session(db.engine) as session:
account = session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
session.commit()
account = db.session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -79,10 +77,9 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
with Session(db.engine) as session:
account = session.merge(account)
account.email = normalized_new_email
session.commit()
account = db.session.merge(account)
account.email = normalized_new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))

View File

@@ -1 +0,0 @@
CURRENT_APP_DSL_VERSION = "0.6.0"

View File

@@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type())
return str(exposed_type().value)
if isinstance(value, str):
return value
try:
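This hunk (and the matching ones below) replaces str(exposed_type()) with exposed_type().value. For a plain Enum, str() yields the qualified member name rather than the payload, which is presumably why the serialized responses needed the change. A small illustration with stand-in enums (not Dify's actual SegmentType):

    from enum import Enum, StrEnum

    class ExposedType(Enum):        # stand-in for the real exposed-type enum
        STRING = "string"

    class ExposedStrType(StrEnum):
        STRING = "string"

    print(str(ExposedType.STRING))      # ExposedType.STRING  <- leaks the member name
    print(ExposedType.STRING.value)     # string
    print(str(ExposedStrType.STRING))   # string  (StrEnum is safe either way)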

View File

@@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return str(value_type.exposed_type())
return value_type.exposed_type().value
class FullContentDict(TypedDict):
@@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
result: FullContentDict = {
"size_bytes": variable_file.size,
"value_type": str(variable_file.value_type.exposed_type()),
"value_type": variable_file.value_type.exposed_type().value,
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
@@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": str(v.value_type.exposed_type()),
"value_type": v.value_type.exposed_type().value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,

View File

@@ -20,13 +20,10 @@ class TenantUserPayload(BaseModel):
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
Get current user.
Get current user
NOTE: user_id is not trusted, it could be maliciously set to any value.
As a result, it could only be considered as an end user id. Even when a
concrete end-user ID is supplied, lookups must stay tenant-scoped so one
tenant cannot bind another tenant's user record into the plugin request
context.
As a result, it could only be considered as an end user id.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
@@ -45,14 +42,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
.limit(1)
)
else:
user_model = session.scalar(
select(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.limit(1)
)
user_model = session.get(EndUser, user_id)
if not user_model:
user_model = EndUser(
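Here the tenant-filtered select is replaced by session.get(EndUser, user_id), a primary-key lookup that consults the session's identity map first and applies no tenant filter. A self-contained sketch of the difference, using a minimal stand-in model rather than the real one:

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class EndUser(Base):  # minimal stand-in for the real model
        __tablename__ = "end_users"
        id: Mapped[str] = mapped_column(primary_key=True)
        tenant_id: Mapped[str] = mapped_column()

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(EndUser(id="u1", tenant_id="t1"))
        session.commit()
        # Old shape: scoped to the tenant, so a foreign tenant gets None.
        scoped = session.scalar(
            select(EndUser).where(EndUser.id == "u1", EndUser.tenant_id == "t2").limit(1)
        )
        # New shape: plain primary-key lookup; tenant is not consulted.
        by_pk = session.get(EndUser, "u1")
        print(scoped, by_pk)  # None <EndUser ...>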

View File

@@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type())
return str(exposed_type().value)
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type())
return str(SegmentType(value).exposed_type().value)
except ValueError:
return value
try:

View File

@@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile

View File

@@ -299,9 +299,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# update prompt tool
for prompt_tool in prompt_messages_tools:
tool_instance = tool_instances.get(prompt_tool.name)
if tool_instance:
self.update_prompt_message_tool(tool_instance, prompt_tool)
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1

View File

@@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class ModelConfigConverter:

View File

@@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)

View File

@@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile

View File

@@ -12,14 +12,13 @@ from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
from graphon.file import FileTransferMethod
from graphon.file.protocols import WorkflowFileRuntimeProtocol
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
from graphon.file.runtime import set_workflow_file_runtime
from graphon.http.protocols import HttpResponseProtocol
if TYPE_CHECKING:
from graphon.file import File
@@ -44,7 +43,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
return dify_config.MULTIMODAL_SEND_FORMAT
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)

View File

@@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
execution.exceptions_count = runtime_state.exceptions_count
def _update_node_execution(
self,

View File

@@ -352,11 +352,11 @@ class DatasourceManager:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
file_id=upload_file.id,
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
file_type=FileType.CUSTOM,
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(record_id=str(upload_file.id)),

View File

@@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
FormType,
ProviderEntity,
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
@@ -318,28 +318,34 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
"""
Validate custom credentials.
:param credentials: provider credentials
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:param session: optional database session
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
if credential_id:
with Session(db.engine) as session:
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
if credential_id:
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
credential_record = session.execute(stmt).scalar_one_or_none()
credential_record = s.execute(stmt).scalar_one_or_none()
# fix origin data
if credential_record and credential_record.encrypted_config:
if not credential_record.encrypted_config.startswith("{"):
original_credentials = {"openai_api_key": credential_record.encrypted_config}
@@ -350,23 +356,31 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
for key, value in credentials.items():
# encrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
return validated_credentials
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
def _generate_provider_credential_name(self, session) -> str:
"""
@@ -443,16 +457,14 @@
:param credential_name: credential name
:return:
"""
with Session(db.engine) as pre_session:
with Session(db.engine) as session:
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(pre_session)
credential_name = self._generate_provider_credential_name(session)
credentials = self.validate_provider_credentials(credentials=credentials)
with Session(db.engine) as session:
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session)
try:
new_record = ProviderCredential(
@@ -465,6 +477,7 @@
session.flush()
if not provider_record:
# If provider record does not exist, create it
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
@@ -517,15 +530,15 @@
:param credential_name: credential name
:return:
"""
with Session(db.engine) as pre_session:
with Session(db.engine) as session:
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=pre_session, exclude_id=credential_id
credential_name=credential_name, session=session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
with Session(db.engine) as session:
credentials = self.validate_provider_credentials(
credentials=credentials, credential_id=credential_id, session=session
)
provider_record = self._get_provider_record(session)
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
@@ -533,10 +546,12 @@
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
@@ -864,6 +879,7 @@
model: str,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
"""
Validate custom model credentials.
@@ -874,14 +890,16 @@
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
if credential_id:
with Session(db.engine) as session:
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
if credential_id:
try:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
@@ -890,7 +908,7 @@
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
credential_record = s.execute(stmt).scalar_one_or_none()
original_credentials = (
json.loads(credential_record.encrypted_config)
if credential_record and credential_record.encrypted_config
@@ -899,23 +917,31 @@
except JSONDecodeError:
original_credentials = {}
for key, value in credentials.items():
# decrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
return validated_credentials
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
@@ -928,22 +954,20 @@
:param credentials: model credentials dict
:return:
"""
with Session(db.engine) as pre_session:
with Session(db.engine) as session:
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=pre_session
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=pre_session
model=model, model_type=model_type, session=session
)
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials
)
with Session(db.engine) as session:
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
)
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
try:
@@ -958,6 +982,7 @@
session.add(credential)
session.flush()
# save provider model
if not provider_model_record:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
@@ -999,24 +1024,23 @@
:param credential_id: credential id
:return:
"""
with Session(db.engine) as pre_session:
with Session(db.engine) as session:
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
session=pre_session,
session=session,
exclude_id=credential_id,
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
)
with Session(db.engine) as session:
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
session=session,
)
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
stmt = select(ProviderModelCredential).where(
@@ -1031,6 +1055,7 @@
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
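Across create_provider_credential, update_provider_credential, and their custom-model counterparts, this refactor folds the separate pre_session/session pair into one Session and lets validate_provider_credentials / validate_custom_model_credentials accept it through the new optional session parameter. The shape, reduced to a hedged sketch (names and bodies here are illustrative, not the Dify API):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")

    def validate_credentials(credentials: dict, session: Session | None = None) -> dict:
        def _validate(s: Session) -> dict:
            # ... look up existing records via `s`, decrypt hidden values,
            # call the provider factory, re-encrypt secrets ...
            return credentials

        if session:
            return _validate(session)          # reuse the caller's session
        with Session(engine) as new_session:   # or own a short-lived one
            return _validate(new_session)

    # Callers can now do the name check, validation, and insert/update
    # inside a single session instead of three:
    with Session(engine) as session:
        validated = validate_credentials({"api_key": "..."}, session=session)
        # ... flush new records with the same `session`, then commit ...
        session.commit()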

View File

@@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = dumps_with_segments(inputs).encode()
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded
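The ensure_ascii=False change stops the serializer from escaping non-ASCII characters before the payload is base64-encoded. Assuming dumps_with_segments forwards the flag to the underlying JSON dump, the effect matches plain json.dumps:

    import json
    from base64 import b64encode

    inputs = {"greeting": "你好"}
    print(json.dumps(inputs))                      # {"greeting": "\u4f60\u597d"}
    raw = json.dumps(inputs, ensure_ascii=False).encode()
    print(raw)                                     # UTF-8 bytes of the original text
    print(b64encode(raw).decode("utf-8"))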

View File

@@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)

View File

@@ -12,7 +12,6 @@ from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
from graphon.http.response import HttpResponse
logger = logging.getLogger(__name__)
@@ -268,47 +267,4 @@ class SSRFProxy:
return patch(url=url, max_retries=max_retries, **kwargs)
def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
"""Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
return HttpResponse(
status_code=response.status_code,
headers=dict(response.headers),
content=response.content,
url=str(response.url) if response.url else None,
reason_phrase=response.reason_phrase,
fallback_text=response.text,
)
class GraphonSSRFProxy:
"""Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
@property
def max_retries_exceeded_error(self) -> type[Exception]:
return max_retries_exceeded_error
@property
def request_error(self) -> type[Exception]:
return request_error
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
ssrf_proxy = SSRFProxy()
graphon_ssrf_proxy = GraphonSSRFProxy()

View File

@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from typing import IO, Any, Literal, Optional, Union, cast, overload
from configs import dify_config
from core.entities import PluginCredentialType
@@ -18,17 +18,15 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
class ModelInstance:
@@ -170,7 +168,7 @@ class ModelInstance:
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
self.model_type_instance.invoke,
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -195,7 +193,7 @@
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return self._round_robin_invoke(
self.model_type_instance.get_num_tokens,
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -215,7 +213,7 @@
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
self.model_type_instance.invoke,
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -237,7 +235,7 @@
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
self.model_type_instance.invoke,
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
@@ -254,7 +252,7 @@
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
self.model_type_instance.get_num_tokens,
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -279,7 +277,7 @@
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
self.model_type_instance.invoke,
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -307,7 +305,7 @@
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
self.model_type_instance.invoke_multimodal_rerank,
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -326,7 +324,7 @@
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return self._round_robin_invoke(
self.model_type_instance.invoke,
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
@@ -342,7 +340,7 @@
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return self._round_robin_invoke(
self.model_type_instance.invoke,
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
@@ -359,14 +357,14 @@
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self._round_robin_invoke(
self.model_type_instance.invoke,
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
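The repeated function= keyword edits above are cosmetic, but the final hunk changes the generic machinery: _round_robin_invoke now declares its type parameters inline with PEP 695 syntax (Python 3.12+), which is why ParamSpec and TypeVar drop out of the typing import and the module-level P/R declarations disappear earlier in this file. A minimal sketch of the method-scoped form:

    from collections.abc import Callable

    class Invoker:
        # PEP 695: the type parameters [**P, R] are scoped to the method,
        # so no module-level P = ParamSpec("P") / R = TypeVar("R") is needed.
        def round_robin_invoke[**P, R](
            self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs
        ) -> R:
            # ... choose credentials, retry on rate limits, then call through ...
            return function(*args, **kwargs)

    print(Invoker().round_robin_invoke(lambda a, b: a + b, 1, 2))  # 3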

View File

@@ -4,20 +4,7 @@ from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.data_exporter.traceclient import (
from core.ops.aliyun_trace.data_exporter.traceclient import (
TraceClient,
build_endpoint,
convert_datetime_to_nanoseconds,
@@ -25,8 +12,8 @@ from dify_trace_aliyun.data_exporter.traceclient import (
convert_to_trace_id,
generate_span_id,
)
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from dify_trace_aliyun.entities.semconv import (
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
from core.ops.aliyun_trace.entities.semconv import (
DIFY_APP_ID,
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@@ -45,7 +32,7 @@ from dify_trace_aliyun.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
from dify_trace_aliyun.utils import (
from core.ops.aliyun_trace.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
@@ -57,6 +44,19 @@ from dify_trace_aliyun.utils import (
get_workflow_node_status,
serialize_json_data,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import AliyunConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.entities import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey

View File

@@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000

View File

@@ -4,8 +4,7 @@ from typing import Any, TypedDict
from opentelemetry.trace import Link, Status, StatusCode
from core.rag.models.document import Document
from dify_trace_aliyun.entities.semconv import (
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
@@ -14,6 +13,7 @@ from dify_trace_aliyun.entities.semconv import (
OUTPUT_VALUE,
GenAISpanKind,
)
from core.rag.models.document import Document
from extensions.ext_database import db
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
from dify_trace_aliyun.data_exporter.traceclient import create_link
from core.ops.aliyun_trace.data_exporter.traceclient import create_link
links = []
if trace_id:

View File

@@ -25,6 +25,7 @@ from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -38,7 +39,6 @@ from core.ops.entities.trace_entity import (
)
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from models.model import EndUser, MessageFile

View File

@@ -1,8 +1,8 @@
from enum import StrEnum
from pydantic import BaseModel
from pydantic import BaseModel, ValidationInfo, field_validator
from core.ops.utils import validate_project_name, validate_url
from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
class TracingProviderEnum(StrEnum):
@@ -52,5 +52,220 @@ class BaseTracingConfig(BaseModel):
return validate_project_name(v, default_name)
class ArizeConfig(BaseTracingConfig):
"""
Model class for Arize tracing config.
"""
api_key: str | None = None
space_id: str | None = None
project: str | None = None
endpoint: str = "https://otlp.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
class PhoenixConfig(BaseTracingConfig):
"""
Model class for Phoenix tracing config.
"""
api_key: str | None = None
project: str | None = None
endpoint: str = "https://app.phoenix.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://app.phoenix.arize.com")
class LangfuseConfig(BaseTracingConfig):
"""
Model class for Langfuse tracing config.
"""
public_key: str
secret_key: str
host: str = "https://api.langfuse.com"
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://api.langfuse.com")
class LangSmithConfig(BaseTracingConfig):
"""
Model class for Langsmith tracing config.
"""
api_key: str
project: str
endpoint: str = "https://api.smith.langchain.com"
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# LangSmith only allows HTTPS
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
class OpikConfig(BaseTracingConfig):
"""
Model class for Opik tracing config.
"""
api_key: str | None = None
project: str | None = None
workspace: str | None = None
url: str = "https://www.comet.com/opik/api/"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "Default Project")
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
class WeaveConfig(BaseTracingConfig):
"""
Model class for Weave tracing config.
"""
api_key: str
entity: str | None = None
project: str
endpoint: str = "https://trace.wandb.ai"
host: str | None = None
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# Weave only allows HTTPS for endpoint
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
if v is not None and v.strip() != "":
return validate_url(v, v, allowed_schemes=("https", "http"))
return v
class AliyunConfig(BaseTracingConfig):
"""
Model class for Aliyun tracing config.
"""
app_name: str = "dify_app"
license_key: str
endpoint: str
@field_validator("app_name")
@classmethod
def app_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
@field_validator("license_key")
@classmethod
def license_key_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("License key cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# aliyun uses two URL formats, which may include a URL path
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
class TencentConfig(BaseTracingConfig):
"""
Tencent APM tracing config
"""
token: str
endpoint: str
service_name: str
@field_validator("token")
@classmethod
def token_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("Token cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
@field_validator("service_name")
@classmethod
def service_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
class MLflowConfig(BaseTracingConfig):
"""
Model class for MLflow tracing config.
"""
tracking_uri: str = "http://localhost:5000"
experiment_id: str = "0" # Default experiment id in MLflow is 0
username: str | None = None
password: str | None = None
@field_validator("tracking_uri")
@classmethod
def tracking_uri_validator(cls, v, info: ValidationInfo):
if isinstance(v, str) and v.startswith("databricks"):
raise ValueError(
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
)
return validate_url_with_path(v, "http://localhost:5000")
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
class DatabricksConfig(BaseTracingConfig):
"""
Model class for Databricks (Databricks-managed MLflow) tracing config.
"""
experiment_id: str
host: str
client_id: str | None = None
client_secret: str | None = None
personal_access_token: str | None = None
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
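Each config class above pairs its fields with pydantic field_validators that normalize endpoint URLs, fill default project names, or reject blank secrets. A self-contained usage sketch against a simplified stand-in (the real validators also call helpers like validate_url_with_path):

    from pydantic import BaseModel, ValidationError, field_validator

    class TencentConfigSketch(BaseModel):  # simplified stand-in, not the real class
        token: str
        endpoint: str = "https://apm.tencentcloudapi.com"

        @field_validator("token")
        @classmethod
        def token_validator(cls, v: str) -> str:
            if not v or v.strip() == "":
                raise ValueError("Token cannot be empty")
            return v

    try:
        TencentConfigSketch(token="   ")
    except ValidationError as exc:
        print(exc.errors()[0]["msg"])  # Value error, Token cannot be empty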

View File

@@ -16,6 +16,7 @@ from langfuse.api.commons.types.usage import Usage
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -27,10 +28,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_langfuse.config import LangfuseConfig
from dify_trace_langfuse.entities.langfuse_trace_entity import (
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
GenerationUsage,
LangfuseGeneration,
LangfuseSpan,
@@ -38,6 +36,8 @@ from dify_trace_langfuse.entities.langfuse_trace_entity import (
LevelEnum,
UnitEnum,
)
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes
from models import EndUser, WorkflowNodeExecutionTriggeredFrom

View File

@@ -9,6 +9,7 @@ from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -20,14 +21,13 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_langsmith.config import LangSmithConfig
from dify_trace_langsmith.entities.langsmith_trace_entity import (
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@@ -11,6 +11,7 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex
from sqlalchemy import select
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -23,7 +24,6 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes
from models import EndUser

View File

@@ -10,6 +10,7 @@ from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -22,7 +23,6 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_opik.config import OpikConfig
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@@ -204,117 +204,114 @@ class TracingProviderConfigEntry(TypedDict):
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
try:
match provider:
case TracingProviderEnum.LANGFUSE:
from dify_trace_langfuse.config import LangfuseConfig
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
match provider:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
return {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
}
return {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
}
case TracingProviderEnum.LANGSMITH:
from dify_trace_langsmith.config import LangSmithConfig
from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
case TracingProviderEnum.LANGSMITH:
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
return {
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
}
return {
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
}
case TracingProviderEnum.OPIK:
from dify_trace_opik.config import OpikConfig
from dify_trace_opik.opik_trace import OpikDataTrace
case TracingProviderEnum.OPIK:
from core.ops.entities.config_entity import OpikConfig
from core.ops.opik_trace.opik_trace import OpikDataTrace
return {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
}
return {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
}
case TracingProviderEnum.WEAVE:
from dify_trace_weave.config import WeaveConfig
from dify_trace_weave.weave_trace import WeaveDataTrace
case TracingProviderEnum.WEAVE:
from core.ops.entities.config_entity import WeaveConfig
from core.ops.weave_trace.weave_trace import WeaveDataTrace
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
case TracingProviderEnum.ARIZE:
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
from dify_trace_arize_phoenix.config import ArizeConfig
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
case TracingProviderEnum.ARIZE:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import ArizeConfig
return {
"config_class": ArizeConfig,
"secret_keys": ["api_key", "space_id"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.PHOENIX:
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
from dify_trace_arize_phoenix.config import PhoenixConfig
return {
"config_class": ArizeConfig,
"secret_keys": ["api_key", "space_id"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.PHOENIX:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import PhoenixConfig
return {
"config_class": PhoenixConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.ALIYUN:
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
return {
"config_class": PhoenixConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.ALIYUN:
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
from core.ops.entities.config_entity import AliyunConfig
return {
"config_class": AliyunConfig,
"secret_keys": ["license_key"],
"other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace,
}
case TracingProviderEnum.MLFLOW:
from dify_trace_mlflow.config import MLflowConfig
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
return {
"config_class": AliyunConfig,
"secret_keys": ["license_key"],
"other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace,
}
case TracingProviderEnum.MLFLOW:
from core.ops.entities.config_entity import MLflowConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": MLflowConfig,
"secret_keys": ["password"],
"other_keys": ["tracking_uri", "experiment_id", "username"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.DATABRICKS:
from dify_trace_mlflow.config import DatabricksConfig
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
return {
"config_class": MLflowConfig,
"secret_keys": ["password"],
"other_keys": ["tracking_uri", "experiment_id", "username"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.DATABRICKS:
from core.ops.entities.config_entity import DatabricksConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": DatabricksConfig,
"secret_keys": ["personal_access_token", "client_secret"],
"other_keys": ["host", "client_id", "experiment_id"],
"trace_instance": MLflowDataTrace,
}
return {
"config_class": DatabricksConfig,
"secret_keys": ["personal_access_token", "client_secret"],
"other_keys": ["host", "client_id", "experiment_id"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.TENCENT:
from dify_trace_tencent.config import TencentConfig
from dify_trace_tencent.tencent_trace import TencentDataTrace
case TracingProviderEnum.TENCENT:
from core.ops.entities.config_entity import TencentConfig
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
except ImportError:
raise ImportError(f"Provider {provider} is not installed.")
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
provider_config_map = OpsTraceProviderConfigMap()
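The map resolves a provider name to its config class, secret keys, and trace implementation, importing each provider's modules inside the matching arm so heavyweight tracing SDKs load only on first access; unknown names raise KeyError. Stripped to a skeleton with placeholder entries:

    import collections
    from typing import TypedDict

    class Entry(TypedDict):
        config_class: type
        secret_keys: list[str]

    class LangfuseConfigSketch:  # placeholder for the real config class
        pass

    class ProviderConfigMapSketch(collections.UserDict[str, Entry]):
        def __getitem__(self, provider: str) -> Entry:
            match provider:
                case "langfuse":
                    # the real map does a local `from ... import LangfuseConfig`
                    # here, deferring the dependency until lookup
                    return {
                        "config_class": LangfuseConfigSketch,
                        "secret_keys": ["public_key", "secret_key"],
                    }
                case _:
                    raise KeyError(f"Unsupported tracing provider: {provider}")

    print(ProviderConfigMapSketch()["langfuse"]["secret_keys"])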

View File

@@ -14,8 +14,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.rag.models.document import Document
from dify_trace_tencent.entities.semconv import (
from core.ops.tencent_trace.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_IS_ENTRY,
@@ -39,8 +38,9 @@ from dify_trace_tencent.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from dify_trace_tencent.utils import TencentTraceUtils
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.rag.models.document import Document
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus

View File

@@ -1,12 +1,14 @@
"""Tencent APM tracing with idempotent client cleanup."""
"""
Tencent APM tracing implementation with separated concerns
"""
import inspect
import logging
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -17,12 +19,11 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.client import TencentTraceClient
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from dify_trace_tencent.client import TencentTraceClient
from dify_trace_tencent.config import TencentConfig
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from dify_trace_tencent.span_builder import TencentSpanBuilder
from dify_trace_tencent.utils import TencentTraceUtils
from extensions.ext_database import db
from graphon.entities.workflow_node_execution import (
WorkflowNodeExecution,
@@ -37,18 +38,10 @@ class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
explicitly in tests or implicitly during garbage collection, so shutdown
must be safe to call multiple times.
"""
trace_client: TencentTraceClient
_closed: bool
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
self._closed = False
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
@@ -520,25 +513,10 @@
except Exception:
logger.debug("[Tencent APM] Failed to record message trace duration")
def close(self) -> None:
"""Synchronously and idempotently shutdown the underlying trace client."""
if getattr(self, "_closed", False):
return
self._closed = True
trace_client = getattr(self, "trace_client", None)
if trace_client is None:
return
def __del__(self):
"""Ensure proper cleanup on garbage collection."""
try:
shutdown_result = trace_client.shutdown()
if inspect.isawaitable(shutdown_result):
close_awaitable = getattr(shutdown_result, "close", None)
if callable(close_awaitable):
close_awaitable()
if hasattr(self, "trace_client"):
self.trace_client.shutdown()
except Exception:
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
def __del__(self):
"""Ensure best-effort cleanup on garbage collection without retrying shutdown."""
self.close()
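
The hunk above swaps between an idempotent close() and a bare __del__ hook (the flattened diff hides which side is which). For reference, a minimal sketch of the idempotent-close pattern the longer docstring describes, assuming only that the client exposes a shutdown() method; the class and client here are hypothetical stand-ins, not Dify code:

import logging

logger = logging.getLogger(__name__)

class TracerWithIdempotentClose:
    def __init__(self, client):
        self._client = client  # assumed to expose shutdown()
        self._closed = False

    def close(self) -> None:
        # Second and later calls are no-ops, so explicit test cleanup and
        # garbage collection can both trigger close() safely.
        if self._closed:
            return
        self._closed = True
        try:
            self._client.shutdown()
        except Exception:
            logger.exception("shutdown failed during cleanup")

    def __del__(self):
        self.close()  # best-effort cleanup on garbage collection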

View File

@ -17,6 +17,7 @@ from weave.trace_server.trace_server_interface import (
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@ -28,9 +29,8 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_weave.config import WeaveConfig
from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_hans
provider_schema.icon_small_dark.zh_Hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_us
else provider_schema.icon_small_dark.en_US
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")

View File

@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class AgentHistoryPromptTransform(PromptTransform):

View File

@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from libs import helper
from models.dataset import Embedding

View File

@ -3,7 +3,6 @@
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
import inspect
import logging
import mimetypes
import os
@ -37,11 +36,8 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
_closed: bool
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
self._closed = False
self.file_path = file_path
self.tenant_id = tenant_id
self.user_id = user_id
@ -69,27 +65,9 @@ class WordExtractor(BaseExtractor):
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
def close(self) -> None:
"""Best-effort cleanup for downloaded temporary files."""
if getattr(self, "_closed", False):
return
self._closed = True
temp_file = getattr(self, "temp_file", None)
if temp_file is None:
return
try:
close_result = temp_file.close()
if inspect.isawaitable(close_result):
close_awaitable = getattr(close_result, "close", None)
if callable(close_awaitable):
close_awaitable()
except Exception:
logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
def __del__(self):
self.close()
if hasattr(self, "temp_file"):
self.temp_file.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""

View File

@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
file_id=upload_file.id,
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
file_type=FileType.IMAGE,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(

View File

@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from libs.helper import parse_uuid_str_or_none
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
@ -517,11 +517,11 @@ class DatasetRetrieval:
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
file_id=upload_file.id,
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
file_type=FileType.IMAGE,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(

View File

@ -9,7 +9,7 @@ from typing import Any, Literal
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):

View File

@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
from core.workflow.human_input_adapter import (
from core.workflow.human_input_compat import (
BoundRecipient,
DeliveryChannelConfig,
EmailDeliveryMethod,

View File

@ -28,7 +28,7 @@ class ToolFileManager:
def _build_graph_file_reference(tool_file: ToolFile) -> File:
extension = guess_extension(tool_file.mimetype) or ".bin"
return File(
file_type=get_file_type_by_mime_type(tool_file.mimetype),
type=get_file_type_by_mime_type(tool_file.mimetype),
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),

View File

@ -1082,12 +1082,7 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable_selector = tool_input.value
if not isinstance(variable_selector, list) or not all(
isinstance(selector_part, str) for selector_part in variable_selector
):
raise ToolParameterError("Variable tool input must be a variable selector")
variable = variable_pool.get(variable_selector)
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value

View File

@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models.tools import ToolModelInvoke

View File

@ -357,10 +357,7 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method_value = file_dict.get("transfer_method")
if not isinstance(transfer_method_value, str):
raise ValueError("Workflow file mapping is missing a valid transfer_method")
transfer_method = FileTransferMethod.value_of(transfer_method_value)
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
match transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id

View File

@ -1,8 +1,8 @@
"""Workflow-to-Graphon adapters for persisted node payloads.
"""Workflow-layer adapters for legacy human-input payload keys.
Stored workflow graphs and editor payloads still contain a small set of
Dify-owned field spellings and value shapes. Adapt them here before handing the
payload to Graphon so Graphon-owned models only see current contracts.
Stored workflow graphs and editor payloads may still use Dify-specific human
input recipient keys. Normalize them here before handing configs to
`graphon` so graph-owned models only see graph-neutral field names.
"""
from __future__ import annotations
@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
return None
def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
@ -215,7 +215,7 @@ def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseMod
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
normalized = adapt_human_input_node_data_for_graph(node_data)
normalized = normalize_human_input_node_data_for_graph(node_data)
raw_delivery_methods = normalized.get("delivery_methods")
if not isinstance(raw_delivery_methods, list):
return []
@ -229,20 +229,17 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
return False
def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
node_type = normalized.get("type")
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
return adapt_human_input_node_data_for_graph(normalized)
if node_type == BuiltinNodeTypes.TOOL:
return _adapt_tool_node_data_for_graph(normalized)
return normalized
if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
return normalized
return normalize_human_input_node_data_for_graph(normalized)
def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_config)
if normalized is None:
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
@ -251,65 +248,10 @@ def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> d
if data_mapping is None:
return normalized
normalized["data"] = adapt_node_data_for_graph(data_mapping)
normalized["data"] = normalize_node_data_for_graph(data_mapping)
return normalized
def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(node_data)
raw_tool_configurations = normalized.get("tool_configurations")
if not isinstance(raw_tool_configurations, Mapping):
return normalized
existing_tool_parameters = normalized.get("tool_parameters")
normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
normalized_tool_configurations: dict[str, Any] = {}
found_legacy_tool_inputs = False
for name, value in raw_tool_configurations.items():
if not isinstance(value, Mapping):
normalized_tool_configurations[name] = value
continue
input_type = value.get("type")
input_value = value.get("value")
if input_type not in {"mixed", "variable", "constant"}:
normalized_tool_configurations[name] = value
continue
found_legacy_tool_inputs = True
normalized_tool_parameters.setdefault(name, dict(value))
flattened_value = _flatten_legacy_tool_configuration_value(
input_type=input_type,
input_value=input_value,
)
if flattened_value is not None:
normalized_tool_configurations[name] = flattened_value
if not found_legacy_tool_inputs:
return normalized
normalized["tool_parameters"] = normalized_tool_parameters
normalized["tool_configurations"] = normalized_tool_configurations
return normalized
def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
return input_value
if (
input_type == "variable"
and isinstance(input_value, list)
and all(isinstance(item, str) for item in input_value)
):
return "{{#" + ".".join(input_value) + "#}}"
return None
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)
@ -349,9 +291,9 @@ __all__ = [
"MemberRecipient",
"WebAppDeliveryMethod",
"_WebAppDeliveryConfig",
"adapt_human_input_node_data_for_graph",
"adapt_node_config_for_graph",
"adapt_node_data_for_graph",
"is_human_input_webapp_enabled",
"normalize_human_input_node_data_for_graph",
"normalize_node_config_for_graph",
"normalize_node_data_for_graph",
"parse_human_input_delivery_methods",
]
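
One side of the hunk above defines _flatten_legacy_tool_configuration_value, which converts a legacy variable selector into the inline template syntax. A standalone sketch mirroring that conversion (the function name and asserts are illustrative):

def flatten_legacy_value(input_type, input_value):
    # Constants and mixed templates pass through unchanged.
    if input_type in {"mixed", "constant"} and isinstance(input_value, (str, int, float, bool)):
        return input_value
    # A selector such as ["start", "query"] becomes the inline template "{{#start.query#}}".
    if input_type == "variable" and isinstance(input_value, list) and all(isinstance(i, str) for i in input_value):
        return "{{#" + ".".join(input_value) + "#}}"
    return None

assert flatten_legacy_value("variable", ["start", "query"]) == "{{#start.query#}}"
assert flatten_legacy_value("constant", 3) == 3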

View File

@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.trigger.constants import TRIGGER_NODE_TYPES
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.human_input_compat import normalize_node_config_for_graph
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.file.file_manager import file_manager
from graphon.graph.graph import NodeFactory
from graphon.model_runtime.memory import PromptMessageMemory
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.nodes.base.node import Node
from graphon.nodes.code.code_node import WorkflowCodeExecutor
from graphon.nodes.code.entities import CodeLanguage
@ -121,7 +121,6 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
"""Resolve the production node class for the requested type/version."""
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
@ -298,7 +297,7 @@ class DifyNodeFactory(NodeFactory):
)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = graphon_ssrf_proxy
self._http_request_http_client = ssrf_proxy
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
self._dify_context,
conversation_id_getter=self._conversation_id,
@ -365,14 +364,10 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
# Graph configs are initially validated against permissive shared node data.
# Re-validate using the resolved node class so workflow-local node schemas
# stay explicit and constructors receive the concrete typed payload.
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@ -396,7 +391,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=resolved_node_data,
node_data=node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@ -410,7 +405,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=resolved_node_data,
node_data=node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@ -420,7 +415,7 @@ class DifyNodeFactory(NodeFactory):
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=resolved_node_data,
node_data=node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
@ -441,8 +436,8 @@ class DifyNodeFactory(NodeFactory):
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
node_id=node_id,
config=resolved_node_data,
id=node_id,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@ -453,10 +448,7 @@ class DifyNodeFactory(NodeFactory):
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
validate_node_data = getattr(node_class, "validate_node_data", None)
if callable(validate_node_data):
return cast("BaseNodeData", validate_node_data(node_data))
return node_data
return node_class.validate_node_data(node_data)
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast, overload
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .human_input_adapter import (
from .human_input_compat import (
BoundRecipient,
DeliveryChannelConfig,
DeliveryMethodType,
@ -173,28 +173,6 @@ class DifyPreparedLLM(PreparedLLMProtocol):
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
@overload
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
def invoke_llm(
self,
*,
@ -212,28 +190,6 @@ class DifyPreparedLLM(PreparedLLMProtocol):
stream=stream,
)
@overload
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
@ -34,18 +35,18 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
node_id: str,
config: AgentNodeData,
*,
id: str,
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
node_id=node_id,
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@ -7,6 +7,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import SystemVariableKey, get_system_segment
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
NodeExecutionType,
@ -35,14 +36,13 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
node_id: str,
config: DatasourceNodeData,
*,
id: str,
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
):
super().__init__(
node_id=node_id,
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@ -7,6 +7,7 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
@ -31,18 +32,12 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
node_id: str,
config: KnowledgeIndexNodeData,
*,
id: str,
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
super().__init__(id, config, graph_init_params, graph_runtime_state)
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()

View File

@ -14,6 +14,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file_reference import parse_file_reference
from graphon.entities import GraphInitParams
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
@ -49,18 +50,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
if value is None or isinstance(value, (str, float)):
return value
if isinstance(value, int) and not isinstance(value, bool):
return value
return str(value)
def _normalize_metadata_filter_sequence_item(value: object) -> str:
return value if isinstance(value, str) else str(value)
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
@ -70,14 +59,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
node_id: str,
config: KnowledgeRetrievalNodeData,
*,
id: str,
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
):
super().__init__(
node_id=node_id,
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@ -294,21 +282,18 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
resolved_value = segment_group.value[0].to_object()
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values: list[str] = []
for v in value:
resolved_values = []
for v in value: # type: ignore
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
resolved_values.append(
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
)
resolved_values.append(segment_group.value[0].to_object())
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
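
The helpers at the top of this file's hunk coerce resolved segment values into the scalar types the metadata filter accepts; note the explicit bool exclusion, since bool subclasses int in Python. A standalone sketch mirroring that coercion (the name is illustrative):

def normalize_metadata_scalar(value):
    if value is None or isinstance(value, (str, float)):
        return value
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    return str(value)  # booleans and other objects fall back to their string form

assert normalize_metadata_scalar(3) == 3
assert normalize_metadata_scalar(True) == "True"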

View File

@ -1,56 +1,17 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
@ -58,152 +19,3 @@ class QuotaType(StrEnum):
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
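
The consume()/refund() contract in this hunk pairs each successful charge with a best-effort refund that is documented never to raise. A hypothetical call site, assuming the QuotaType API exactly as shown above; the import path and execute_trigger are illustrative stand-ins, not real Dify names:

from services.quota import QuotaType  # import path assumed for illustration

def run_with_quota(tenant_id: str) -> None:
    charge = QuotaType.TRIGGER.consume(tenant_id)  # raises QuotaExceededError when exhausted
    try:
        execute_trigger()  # hypothetical billable work
    except Exception:
        charge.refund()  # best-effort; safe even when billing is disabled
        raise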

View File

@ -10,8 +10,8 @@ from typing import Any
from sqlalchemy import select
from core.app.file_access import FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.workflow.file_reference import build_file_reference
from extensions.ext_database import db
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
from models import ToolFile, UploadFile
@ -135,30 +135,29 @@ def _build_from_local_file(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
with session_factory.create_session() as session:
row = session.scalar(access_controller.apply_upload_file_filters(stmt))
if row is None:
raise ValueError("Invalid upload file")
row = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
if row is None:
raise ValueError("Invalid upload file")
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type", "custom"),
strict_type_validation=strict_type_validation,
)
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type", "custom"),
strict_type_validation=strict_type_validation,
)
return File(
file_id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
size=row.size,
storage_key=row.key,
)
return File(
id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
size=row.size,
storage_key=row.key,
)
def _build_from_remote_url(
@ -180,33 +179,32 @@ def _build_from_remote_url(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
with session_factory.create_session() as session:
upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
if upload_file is None:
raise ValueError("Invalid upload file")
upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
if upload_file is None:
raise ValueError("Invalid upload file")
detected_file_type = standardize_file_type(
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
detected_file_type = standardize_file_type(
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
return File(
file_id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
size=upload_file.size,
storage_key=upload_file.key,
)
return File(
id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
size=upload_file.size,
storage_key=upload_file.key,
)
url = mapping.get("url") or mapping.get("remote_url")
if not url:
@ -222,9 +220,9 @@ def _build_from_remote_url(
)
return File(
file_id=mapping.get("id"),
id=mapping.get("id"),
filename=filename,
file_type=file_type,
type=file_type,
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
@ -249,31 +247,30 @@ def _build_from_tool_file(
ToolFile.id == tool_file_id,
ToolFile.tenant_id == tenant_id,
)
with session_factory.create_session() as session:
tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt))
if tool_file is None:
raise ValueError(f"ToolFile {tool_file_id} not found")
tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt))
if tool_file is None:
raise ValueError(f"ToolFile {tool_file_id} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
return File(
file_id=mapping.get("id"),
filename=tool_file.name,
file_type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
storage_key=tool_file.file_key,
)
return File(
id=mapping.get("id"),
filename=tool_file.name,
type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
storage_key=tool_file.file_key,
)
def _build_from_datasource_file(
@ -292,32 +289,31 @@ def _build_from_datasource_file(
UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
with session_factory.create_session() as session:
datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
if datasource_file is None:
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
if datasource_file is None:
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
return File(
file_id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
file_type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),
extension=extension,
mime_type=datasource_file.mime_type,
size=datasource_file.size,
storage_key=datasource_file.key,
url=datasource_file.source_url,
)
return File(
id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),
extension=extension,
mime_type=datasource_file.mime_type,
size=datasource_file.size,
storage_key=datasource_file.key,
url=datasource_file.source_url,
)
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:

View File

@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
return str(v.value_type.exposed_type())
return v.value_type.exposed_type().value
else:
value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
return str(value_type.exposed_type())
return value_type.exposed_type().value
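
This hunk, like the two that follow it, toggles between str(exposed_type()) and exposed_type().value. For a plain Enum the two spellings differ, which is what makes the change observable; a small sketch (the enum names are illustrative):

from enum import Enum, StrEnum

class Exposed(Enum):
    STRING = "string"

class ExposedStr(StrEnum):
    STRING = "string"

assert str(Exposed.STRING) == "Exposed.STRING"  # class-qualified form
assert Exposed.STRING.value == "string"         # the bare serialized value
assert str(ExposedStr.STRING) == "string"       # StrEnum's str() already equals .value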

View File

@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type())
return str(exposed_type().value)
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type())
return str(SegmentType(value).exposed_type().value)
except ValueError:
return value
try:

View File

@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
"value_type": str(value.value_type.exposed_type()),
"value_type": value.value_type.exposed_type().value,
"description": value.description,
}
if isinstance(value, dict):

View File

@ -6,8 +6,8 @@ from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.helper.http_client_pooling import get_pooled_http_client
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@ -95,28 +95,27 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
session.add(new_data_source_binding)
session.commit()
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token)
@ -131,57 +130,55 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
session.add(new_data_source_binding)
session.commit()
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def sync_data_source(self, binding_id: str) -> None:
# save data source binding
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
)
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
new_source_info = self._build_source_info(
workspace_name=source_info["workspace_name"],
workspace_icon=source_info["workspace_icon"],
workspace_id=source_info["workspace_id"],
pages=pages,
)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
raise ValueError("Data source binding not found")
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
new_source_info = self._build_source_info(
workspace_name=source_info["workspace_name"],
workspace_icon=source_info["workspace_icon"],
workspace_id=source_info["workspace_id"],
pages=pages,
)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = []

View File

@ -3,7 +3,6 @@
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
@ -37,24 +36,24 @@ class WorkflowComment(Base):
__tablename__ = "workflow_comments"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(sa.Float)
position_y: Mapped[float] = mapped_column(sa.Float)
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
position_x: Mapped[float] = mapped_column(db.Float)
position_y: Mapped[float] = mapped_column(db.Float)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
# Relationships
@ -144,20 +143,20 @@ class WorkflowCommentReply(Base):
__tablename__ = "workflow_comment_replies"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
@ -188,18 +187,18 @@ class WorkflowCommentMention(Base):
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[str | None] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

View File

@ -1715,7 +1715,7 @@ class SegmentAttachmentBinding(TypeBase):
)
class DocumentSegmentSummary(TypeBase):
class DocumentSegmentSummary(Base):
__tablename__ = "document_segment_summaries"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
@ -1725,40 +1725,25 @@ class DocumentSegmentSummary(TypeBase):
sa.Index("document_segment_summaries_status_idx", "status"),
)
id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# corresponds to DocumentSegment.id or parent chunk id
chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
status: Mapped[SummaryStatus] = mapped_column(
EnumText(SummaryStatus, length=32),
nullable=False,
server_default=sa.text("'generating'"),
default=SummaryStatus.GENERATING,
)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
status: Mapped[str] = mapped_column(
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
)
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
def __repr__(self):

View File

@ -6,7 +6,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import Mapped, mapped_column, relationship
from core.workflow.human_input_adapter import DeliveryMethodType
from core.workflow.human_input_compat import DeliveryMethodType
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.helper import generate_string

View File

@ -5,8 +5,7 @@ from functools import lru_cache
from typing import Any
from core.workflow.file_reference import parse_file_reference
from graphon.file import File, FileTransferMethod, FileType
from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
from graphon.file import File, FileTransferMethod
@lru_cache(maxsize=1)
@ -44,124 +43,6 @@ def resolve_file_mapping_tenant_id(
return tenant_resolver()
def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
"""Build a graph `File` directly from serialized metadata."""
def _coerce_file_type(value: Any) -> FileType:
if isinstance(value, FileType):
return value
if isinstance(value, str):
return FileType.value_of(value)
raise ValueError("file type is required in file mapping")
mapping = dict(file_mapping)
transfer_method_value = mapping.get("transfer_method")
if isinstance(transfer_method_value, FileTransferMethod):
transfer_method = transfer_method_value
elif isinstance(transfer_method_value, str):
transfer_method = FileTransferMethod.value_of(transfer_method_value)
else:
raise ValueError("transfer_method is required in file mapping")
file_id = mapping.get("file_id")
if not isinstance(file_id, str) or not file_id:
legacy_id = mapping.get("id")
file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
related_id = resolve_file_record_id(mapping)
if related_id is None:
raw_related_id = mapping.get("related_id")
related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
remote_url = mapping.get("remote_url")
if not isinstance(remote_url, str) or not remote_url:
url = mapping.get("url")
remote_url = url if isinstance(url, str) and url else None
reference = mapping.get("reference")
if not isinstance(reference, str) or not reference:
reference = None
filename = mapping.get("filename")
if not isinstance(filename, str):
filename = None
extension = mapping.get("extension")
if not isinstance(extension, str):
extension = None
mime_type = mapping.get("mime_type")
if not isinstance(mime_type, str):
mime_type = None
size = mapping.get("size", -1)
if not isinstance(size, int):
size = -1
storage_key = mapping.get("storage_key")
if not isinstance(storage_key, str):
storage_key = None
tenant_id = mapping.get("tenant_id")
if not isinstance(tenant_id, str):
tenant_id = None
dify_model_identity = mapping.get("dify_model_identity")
if not isinstance(dify_model_identity, str):
dify_model_identity = FILE_MODEL_IDENTITY
tool_file_id = mapping.get("tool_file_id")
if not isinstance(tool_file_id, str):
tool_file_id = None
upload_file_id = mapping.get("upload_file_id")
if not isinstance(upload_file_id, str):
upload_file_id = None
datasource_file_id = mapping.get("datasource_file_id")
if not isinstance(datasource_file_id, str):
datasource_file_id = None
return File(
file_id=file_id,
tenant_id=tenant_id,
file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
transfer_method=transfer_method,
remote_url=remote_url,
reference=reference,
related_id=related_id,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
storage_key=storage_key,
dify_model_identity=dify_model_identity,
url=remote_url,
tool_file_id=tool_file_id,
upload_file_id=upload_file_id,
datasource_file_id=datasource_file_id,
)
def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
"""Recursively rebuild serialized graph file payloads into `File` objects.
`graphon` 0.2.2 no longer accepts legacy serialized file mappings via
`model_validate_json()`. Dify keeps this recovery path at the model boundary
so historical JSON blobs remain readable without reintroducing global graph
patches or test-local coercion.
"""
if isinstance(value, list):
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
if isinstance(value, dict):
if maybe_file_object(value):
return build_file_from_mapping_without_lookup(file_mapping=value)
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
return value
def build_file_from_stored_mapping(
*,
file_mapping: Mapping[str, Any],
@ -195,7 +76,12 @@ def build_file_from_stored_mapping(
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
return build_file_from_mapping_without_lookup(file_mapping=mapping)
remote_url = mapping.get("remote_url")
if not isinstance(remote_url, str) or not remote_url:
url = mapping.get("url")
if isinstance(url, str) and url:
mapping["remote_url"] = url
return File.model_validate(mapping)
return file_factory.build_from_mapping(
mapping=mapping,

View File

@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.human_input_compat import normalize_node_config_for_graph
from core.workflow.variable_prefixes import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@ -64,10 +64,7 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
from .utils.file_input_compat import (
build_file_from_mapping_without_lookup,
build_file_from_stored_mapping,
)
from .utils.file_input_compat import build_file_from_stored_mapping
logger = logging.getLogger(__name__)
@ -293,7 +290,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
@staticmethod
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
@ -1691,7 +1688,7 @@ class WorkflowDraftVariable(Base):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
return File.model_validate(normalized_file)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
@ -1701,7 +1698,7 @@ class WorkflowDraftVariable(Base):
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
normalized_file.pop("tenant_id", None)
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
file_list.append(File.model_validate(normalized_file))
return cast(Any, file_list)
else:
return cast(Any, value)

View File

@ -10,6 +10,3 @@ This directory holds **optional workspace packages** that plug into Dify's API
Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).
## Excluding Providers
To build with only selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable the default bundles, then use `--group vdb-<provider>` and `--group trace-<provider>` to enable specific providers.

View File

@ -1,78 +0,0 @@
# Trace providers
This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others).
Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping.
## Architecture
| Layer | Location | Role |
|--------|----------|------|
| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads |
| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class |
| Your package | `api/providers/trace/trace-<name>/` | Pydantic config + subclass of `BaseTraceInstance` |
At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype.
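In toy form, that flow is roughly the following; the registry-entry shape and the helper function are assumptions for illustration, not the actual manager code:

```python
from typing import Any

# Toy registry mirroring the "provider string -> entry" mapping described above.
REGISTRY: dict[str, dict[str, Any]] = {}


def build_trace_instance(provider: str, decrypted_config: dict[str, Any]) -> Any:
    """Illustrative only: look up the provider, validate its config, build the instance."""
    entry = REGISTRY[provider]                          # registry lookup by provider string
    config = entry["config_class"](**decrypted_config)  # Pydantic validation happens here
    instance = entry["trace_instance"](config)          # cached per app in the real manager
    return instance                                     # later: instance.trace(trace_info)
```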
## What you implement
### 1. Config model (`BaseTracingConfig`)
Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate.
Fields fall into two groups used by the manager:
- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords).
- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints).
List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct.
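A minimal sketch for a hypothetical `mybackend` provider; the field names are illustrative, and the validator helpers mirror the patterns used by the existing provider configs:

```python
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path


class MyBackendConfig(BaseTracingConfig):
    """Hypothetical config: `api_key` belongs in `secret_keys`;
    `endpoint` and `project` belong in `other_keys`."""

    api_key: str
    project: str | None = None
    endpoint: str = "https://api.mybackend.example"

    @field_validator("project")
    @classmethod
    def project_validator(cls, v, info: ValidationInfo):
        return cls.validate_project_field(v, "default")

    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        return validate_url_with_path(v, "https://api.mybackend.example")
```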
### 2. Trace instance (`BaseTraceInstance`)
Subclass `BaseTraceInstance` and implement:
```python
def trace(self, trace_info: BaseTraceInfo) -> None:
...
```
Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including:
- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace`
- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo`
- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo`
You may ignore categories your backend does not support; existing providers often no-op unhandled types.
Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context.
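A compressed sketch, assuming the hypothetical `MyBackendConfig` above; the constructor shape and the vendor-client wiring are assumptions, so mirror `trace_langfuse` or `trace_langsmith` for the real pattern:

```python
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.entities.trace_entity import BaseTraceInfo, MessageTraceInfo, WorkflowTraceInfo


class MyBackendTrace(BaseTraceInstance):
    """Hypothetical trace instance; the vendor client is stubbed out."""

    def __init__(self, config: BaseTracingConfig):
        # Constructor shape is an assumption; check the base class before copying.
        self._config = config

    def trace(self, trace_info: BaseTraceInfo) -> None:
        # Dispatch on the concrete payload type; unhandled categories are deliberate no-ops.
        if isinstance(trace_info, WorkflowTraceInfo):
            self._export("workflow", trace_info)
        elif isinstance(trace_info, MessageTraceInfo):
            self._export("message", trace_info)

    def _export(self, kind: str, info: BaseTraceInfo) -> None:
        ...  # map fields onto the vendor SDK here (spans, attributes, flush)
```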
### 3. Register in the API core
Upstream changes are required so Dify knows your provider exists:
1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`).
2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning:
- `config_class`: your Pydantic config type
- `secret_keys` / `other_keys`: lists of field names as above
- `trace_instance`: your `BaseTraceInstance` subclass
Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`.
If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app.
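Sketched together, the two edits might look like this; the enum base class and the exact shape of the returned entry are assumptions, so copy an existing `case`:

```python
# api/core/ops/entities/config_entity.py: add the stable provider string.
from enum import StrEnum


class TracingProviderEnum(StrEnum):  # base class is an assumption; mirror the real enum
    MYBACKEND = "mybackend"


# api/core/ops/ops_trace_manager.py: the new case inside
# OpsTraceProviderConfigMap.__getitem__, shown here as a standalone sketch.
def _getitem_sketch(provider: TracingProviderEnum) -> dict[str, object]:
    match provider:
        case TracingProviderEnum.MYBACKEND:
            # Lazy import so a missing optional install raises a clear ImportError.
            from dify_trace_mybackend.config import MyBackendConfig
            from dify_trace_mybackend.mybackend_trace import MyBackendTrace

            return {
                "config_class": MyBackendConfig,
                "secret_keys": ["api_key"],
                "other_keys": ["endpoint", "project"],
                "trace_instance": MyBackendTrace,
            }
    raise KeyError(provider)
```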
## Package layout
Each provider is a normal uv workspace member, for example:
- `api/providers/trace/trace-<name>/pyproject.toml` — project name `dify-trace-<name>`, dependencies on vendor SDKs
- `api/providers/trace/trace-<name>/src/dify_trace_<name>/` — containing `config.py`, `<name>_trace.py`, an optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`.
Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`.
## Wiring into the `api` workspace
In `api/pyproject.toml`:
1. **`[tool.uv.sources]`** — `dify-trace-<name> = { workspace = true }`
2. **`[dependency-groups]`** — add `trace-<name> = ["dify-trace-<name>"]` and include `dify-trace-<name>` in `trace-all` if it should ship with the default bundle
After changing metadata, run **`uv sync`** from `api/`.

View File

@ -1,14 +0,0 @@
[project]
name = "dify-trace-aliyun"
version = "0.0.1"
dependencies = [
# versions inherited from parent
"opentelemetry-api",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
"opentelemetry-semantic-conventions",
]
description = "Dify ops tracing provider (Aliyun)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,32 +0,0 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path
class AliyunConfig(BaseTracingConfig):
"""
Model class for Aliyun tracing config.
"""
app_name: str = "dify_app"
license_key: str
endpoint: str
@field_validator("app_name")
@classmethod
def app_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
@field_validator("license_key")
@classmethod
def license_key_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("License key cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# aliyun uses two URL formats, which may include a URL path
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")

View File

@ -1,85 +0,0 @@
import pytest
from dify_trace_aliyun.config import AliyunConfig
from pydantic import ValidationError
class TestAliyunConfig:
"""Test cases for AliyunConfig"""
def test_valid_config(self):
"""Test valid Aliyun configuration"""
config = AliyunConfig(
app_name="test_app",
license_key="test_license_key",
endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
)
assert config.app_name == "test_app"
assert config.license_key == "test_license_key"
assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
assert config.app_name == "dify_app"
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
AliyunConfig()
with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license")
with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
def test_app_name_validation_empty(self):
"""Test app_name validation with empty value"""
config = AliyunConfig(
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
)
assert config.app_name == "dify_app"
def test_endpoint_validation_empty(self):
"""Test endpoint validation with empty value"""
config = AliyunConfig(license_key="test_license", endpoint="")
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation preserves path for Aliyun endpoints"""
config = AliyunConfig(
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
)
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
def test_endpoint_validation_invalid_scheme(self):
"""Test endpoint validation rejects invalid schemes"""
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
def test_endpoint_validation_no_scheme(self):
"""Test endpoint validation rejects URLs without scheme"""
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
def test_license_key_required(self):
"""Test that license_key is required and cannot be empty"""
with pytest.raises(ValidationError):
AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
def test_valid_endpoint_format_examples(self):
"""Test valid endpoint format examples from comments"""
valid_endpoints = [
# cms2.0 public endpoint
"https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
# cms2.0 intranet endpoint
"https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
# xtrace public endpoint
"http://tracing-cn-heyuan.arms.aliyuncs.com",
# xtrace intranet endpoint
"http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
]
for endpoint in valid_endpoints:
config = AliyunConfig(license_key="test_license", endpoint=endpoint)
assert config.endpoint == endpoint

View File

@ -1,10 +0,0 @@
[project]
name = "dify-trace-arize-phoenix"
version = "0.0.1"
dependencies = [
"arize-phoenix-otel~=0.15.0",
]
description = "Dify ops tracing provider (Arize / Phoenix)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,45 +0,0 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path
class ArizeConfig(BaseTracingConfig):
"""
Model class for Arize tracing config.
"""
api_key: str | None = None
space_id: str | None = None
project: str | None = None
endpoint: str = "https://otlp.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
class PhoenixConfig(BaseTracingConfig):
"""
Model class for Phoenix tracing config.
"""
api_key: str | None = None
project: str | None = None
endpoint: str = "https://app.phoenix.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://app.phoenix.arize.com")

View File

@ -1,88 +0,0 @@
import pytest
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from pydantic import ValidationError
class TestArizeConfig:
"""Test cases for ArizeConfig"""
def test_valid_config(self):
"""Test valid Arize configuration"""
config = ArizeConfig(
api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
)
assert config.api_key == "test_key"
assert config.space_id == "test_space"
assert config.project == "test_project"
assert config.endpoint == "https://custom.arize.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = ArizeConfig()
assert config.api_key is None
assert config.space_id is None
assert config.project is None
assert config.endpoint == "https://otlp.arize.com"
def test_project_validation_empty(self):
"""Test project validation with empty value"""
config = ArizeConfig(project="")
assert config.project == "default"
def test_project_validation_none(self):
"""Test project validation with None value"""
config = ArizeConfig(project=None)
assert config.project == "default"
def test_endpoint_validation_empty(self):
"""Test endpoint validation with empty value"""
config = ArizeConfig(endpoint="")
assert config.endpoint == "https://otlp.arize.com"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation normalizes URL by removing path"""
config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
assert config.endpoint == "https://custom.arize.com"
def test_endpoint_validation_invalid_scheme(self):
"""Test endpoint validation rejects invalid schemes"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
ArizeConfig(endpoint="ftp://invalid.com")
def test_endpoint_validation_no_scheme(self):
"""Test endpoint validation rejects URLs without scheme"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
ArizeConfig(endpoint="invalid.com")
class TestPhoenixConfig:
"""Test cases for PhoenixConfig"""
def test_valid_config(self):
"""Test valid Phoenix configuration"""
config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
assert config.api_key == "test_key"
assert config.project == "test_project"
assert config.endpoint == "https://custom.phoenix.com"
def test_default_values(self):
"""Test default values are set correctly"""
config = PhoenixConfig()
assert config.api_key is None
assert config.project is None
assert config.endpoint == "https://app.phoenix.arize.com"
def test_project_validation_empty(self):
"""Test project validation with empty value"""
config = PhoenixConfig(project="")
assert config.project == "default"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation with path"""
config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
def test_endpoint_validation_without_path(self):
"""Test endpoint validation without path"""
config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
assert config.endpoint == "https://app.phoenix.arize.com"

View File

@ -1,10 +0,0 @@
[project]
name = "dify-trace-langfuse"
version = "0.0.1"
dependencies = [
"langfuse>=4.2.0,<5.0.0",
]
description = "Dify ops tracing provider (Langfuse)."
[tool.setuptools.packages.find]
where = ["src"]

View File

@ -1,19 +0,0 @@
from pydantic import ValidationInfo, field_validator
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path
class LangfuseConfig(BaseTracingConfig):
"""
Model class for Langfuse tracing config.
"""
public_key: str
secret_key: str
host: str = "https://api.langfuse.com"
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://api.langfuse.com")

View File

@ -1,42 +0,0 @@
import pytest
from dify_trace_langfuse.config import LangfuseConfig
from pydantic import ValidationError
class TestLangfuseConfig:
"""Test cases for LangfuseConfig"""
def test_valid_config(self):
"""Test valid Langfuse configuration"""
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
assert config.public_key == "public_key"
assert config.secret_key == "secret_key"
assert config.host == "https://custom.langfuse.com"
def test_valid_config_with_path(self):
host = "https://custom.langfuse.com/api/v1"
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
assert config.public_key == "public_key"
assert config.secret_key == "secret_key"
assert config.host == host
def test_default_values(self):
"""Test default values are set correctly"""
config = LangfuseConfig(public_key="public", secret_key="secret")
assert config.host == "https://api.langfuse.com"
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangfuseConfig()
with pytest.raises(ValidationError):
LangfuseConfig(public_key="public")
with pytest.raises(ValidationError):
LangfuseConfig(secret_key="secret")
def test_host_validation_empty(self):
"""Test host validation with empty value"""
config = LangfuseConfig(public_key="public", secret_key="secret", host="")
assert config.host == "https://api.langfuse.com"

Some files were not shown because too many files have changed in this diff.