Merge branch 'fix/app-detail-panel-merge-issue' into dev/plugin-deploy

This commit is contained in:
Yi
2025-02-10 12:01:57 +08:00
199 changed files with 3636 additions and 2136 deletions

View File

@ -48,16 +48,18 @@ ENV TZ=UTC
WORKDIR /app/api
RUN apt-get update \
&& apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# if you located in China, you can use aliyun mirror to speed up
# && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
# && apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.19+dfsg-1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
# install a chinese font to support the use of tools like matplotlib
&& apt-get install -y fonts-noto-cjk \
RUN \
apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security
# expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install a chinese font to support the use of tools like matplotlib
fonts-noto-cjk \
# install libmagic to support the use of python-magic guess MIMETYPE
libmagic1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*
@ -80,7 +82,6 @@ COPY . /app/api/
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}

View File

@ -556,6 +556,11 @@ class AuthConfig(BaseSettings):
default=86400,
)
FORGOT_PASSWORD_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying password reset after exceeding the rate limit.",
default=86400,
)
class ModerationConfig(BaseSettings):
"""

View File

@ -1,9 +1,40 @@
from typing import Optional
from pydantic import Field, NonNegativeInt
from pydantic import Field, NonNegativeInt, computed_field
from pydantic_settings import BaseSettings
class HostedCreditConfig(BaseSettings):
HOSTED_MODEL_CREDIT_CONFIG: str = Field(
description="Model credit configuration in format 'model:credits,model:credits', e.g., 'gpt-4:20,gpt-4o:10'",
default="",
)
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
Returns 1 if model is not found in configuration (default credit).
:param model_name: The name of the model to search for
:return: The credit value for the model
"""
if not self.HOSTED_MODEL_CREDIT_CONFIG:
return 1
try:
credit_map = dict(
item.strip().split(":", 1) for item in self.HOSTED_MODEL_CREDIT_CONFIG.split(",") if ":" in item
)
# Search for matching model pattern
for pattern, credit in credit_map.items():
if pattern.strip() == model_name:
return int(credit)
return 1 # Default quota if no match found
except (ValueError, AttributeError):
return 1 # Return default quota if parsing fails
class HostedOpenAiConfig(BaseSettings):
"""
Configuration for hosted OpenAI service
@ -202,5 +233,7 @@ class HostedServiceConfig(
HostedZhipuAIConfig,
# moderation
HostedModerationConfig,
# credit config
HostedCreditConfig,
):
pass

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="1.0.0-beta.1",
default="1.0.0",
)
COMMIT_SHA: str = Field(

View File

@ -1,12 +1,32 @@
import mimetypes
import os
import platform
import re
import urllib.parse
import warnings
from collections.abc import Mapping
from typing import Any
from uuid import uuid4
import httpx
try:
import magic
except ImportError:
if platform.system() == "Windows":
warnings.warn(
"To use python-magic guess MIMETYPE, you need to run `pip install python-magic-bin`", stacklevel=2
)
elif platform.system() == "Darwin":
warnings.warn("To use python-magic guess MIMETYPE, you need to run `brew install libmagic`", stacklevel=2)
elif platform.system() == "Linux":
warnings.warn(
"To use python-magic guess MIMETYPE, you need to run `sudo apt-get install libmagic1`", stacklevel=2
)
else:
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
magic = None # type: ignore
from pydantic import BaseModel
from configs import dify_config
@ -47,6 +67,13 @@ def guess_file_info_from_response(response: httpx.Response):
# If guessing fails, use Content-Type from response headers
mimetype = response.headers.get("Content-Type", "application/octet-stream")
# Use python-magic to guess MIME type if still unknown or generic
if mimetype == "application/octet-stream" and magic is not None:
try:
mimetype = magic.from_buffer(response.content[:1024], mime=True)
except magic.MagicException:
pass
extension = os.path.splitext(filename)[1]
# Ensure filename has an extension

View File

@ -59,3 +59,9 @@ class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException):
error_code = "email_code_account_deletion_rate_limit_exceeded"
description = "Too many account deletion emails have been sent. Please try again in 5 minutes."
code = 429
class EmailPasswordResetLimitError(BaseHTTPException):
error_code = "email_password_reset_limit"
description = "Too many failed password reset attempts. Please try again in 24 hours."
code = 429

View File

@ -8,7 +8,13 @@ from sqlalchemy.orm import Session
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import setup_required
from events.tenant_event import tenant_was_created
@ -65,6 +71,10 @@ class ForgotPasswordCheckApi(Resource):
user_email = args["email"]
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
if token_data is None:
raise InvalidTokenError()
@ -73,8 +83,10 @@ class ForgotPasswordCheckApi(Resource):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
raise EmailCodeError()
AccountService.reset_forgot_password_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email")}

View File

@ -50,7 +50,7 @@ class MessageListApi(InstalledAppResource):
try:
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,3 +1,5 @@
import json
from flask_restful import Resource, reqparse # type: ignore
from controllers.console.wraps import setup_required
@ -29,4 +31,34 @@ class EnterpriseWorkspace(Resource):
return {"message": "enterprise workspace created."}
class EnterpriseWorkspaceNoOwnerEmail(Resource):
@setup_required
@enterprise_inner_api_only
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
tenant_was_created.send(tenant)
resp = {
"id": tenant.id,
"name": tenant.name,
"encrypt_public_key": tenant.encrypt_public_key,
"plan": tenant.plan,
"status": tenant.status,
"custom_config": json.loads(tenant.custom_config) if tenant.custom_config else {},
"created_at": tenant.created_at.isoformat() if tenant.created_at else None,
"updated_at": tenant.updated_at.isoformat() if tenant.updated_at else None,
}
return {
"message": "enterprise workspace created.",
"tenant": resp,
}
api.add_resource(EnterpriseWorkspace, "/enterprise/workspace")
api.add_resource(EnterpriseWorkspaceNoOwnerEmail, "/enterprise/workspace/ownerless")

View File

@ -18,6 +18,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError,
DocumentIndexingError,
InvalidMetadataError,
)
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from core.errors.error import ProviderTokenNotInitError
@ -50,6 +51,9 @@ class DocumentAddByTextApi(DatasetApiResource):
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("doc_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("doc_metadata", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@ -61,6 +65,28 @@ class DocumentAddByTextApi(DatasetApiResource):
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
# Validate metadata if provided
if args.get("doc_type") or args.get("doc_metadata"):
if not args.get("doc_type") or not args.get("doc_metadata"):
raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
raise InvalidMetadataError(
"Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
)
if not isinstance(args["doc_metadata"], dict):
raise InvalidMetadataError("doc_metadata must be a dictionary")
# Validate metadata schema based on doc_type
if args["doc_type"] != "others":
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
for key, value in args["doc_metadata"].items():
if key in metadata_schema and not isinstance(value, metadata_schema[key]):
raise InvalidMetadataError(f"Invalid type for metadata field {key}")
# set to MetaDataConfig
args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
text = args.get("text")
name = args.get("name")
if text is None or name is None:
@ -107,6 +133,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("doc_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("doc_metadata", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@ -115,6 +143,32 @@ class DocumentUpdateByTextApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset is not exist.")
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
# Validate metadata if provided
if args.get("doc_type") or args.get("doc_metadata"):
if not args.get("doc_type") or not args.get("doc_metadata"):
raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
raise InvalidMetadataError(
"Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
)
if not isinstance(args["doc_metadata"], dict):
raise InvalidMetadataError("doc_metadata must be a dictionary")
# Validate metadata schema based on doc_type
if args["doc_type"] != "others":
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
for key, value in args["doc_metadata"].items():
if key in metadata_schema and not isinstance(value, metadata_schema[key]):
raise InvalidMetadataError(f"Invalid type for metadata field {key}")
# set to MetaDataConfig
args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
if args["text"]:
text = args.get("text")
name = args.get("name")
@ -161,6 +215,30 @@ class DocumentAddByFileApi(DatasetApiResource):
args["doc_form"] = "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"
# Validate metadata if provided
if args.get("doc_type") or args.get("doc_metadata"):
if not args.get("doc_type") or not args.get("doc_metadata"):
raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
raise InvalidMetadataError(
"Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
)
if not isinstance(args["doc_metadata"], dict):
raise InvalidMetadataError("doc_metadata must be a dictionary")
# Validate metadata schema based on doc_type
if args["doc_type"] != "others":
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
for key, value in args["doc_metadata"].items():
if key in metadata_schema and not isinstance(value, metadata_schema[key]):
raise InvalidMetadataError(f"Invalid type for metadata field {key}")
# set to MetaDataConfig
args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@ -228,6 +306,29 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if "doc_language" not in args:
args["doc_language"] = "English"
# Validate metadata if provided
if args.get("doc_type") or args.get("doc_metadata"):
if not args.get("doc_type") or not args.get("doc_metadata"):
raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
raise InvalidMetadataError(
"Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
)
if not isinstance(args["doc_metadata"], dict):
raise InvalidMetadataError("doc_metadata must be a dictionary")
# Validate metadata schema based on doc_type
if args["doc_type"] != "others":
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
for key, value in args["doc_metadata"].items():
if key in metadata_schema and not isinstance(value, metadata_schema[key]):
raise InvalidMetadataError(f"Invalid type for metadata field {key}")
# set to MetaDataConfig
args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)

View File

@ -91,7 +91,7 @@ class MessageListApi(WebApiResource):
try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -8,16 +8,16 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)
@ -191,7 +191,8 @@ class AgentChatAppRunner(AppRunner):
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
assert model_schema is not None
if not model_schema:
raise ValueError("Model schema not found")
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
@ -247,29 +248,3 @@ class AgentChatAppRunner(AppRunner):
stream=application_generate_entity.stream,
agent=True,
)
def _get_usage_of_all_agent_thoughts(
self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
agent_thoughts = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
)
all_message_tokens = 0
all_answer_tokens = 0
for agent_thought in agent_thoughts:
all_message_tokens += agent_thought.message_tokens
all_answer_tokens += agent_thought.answer_tokens
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage(
model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
)

View File

@ -11,15 +11,6 @@ from configs import dify_config
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
proxy_mounts = (
{
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL
else None
)
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
@ -50,7 +41,11 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client:
response = client.request(method=method, url=url, **kwargs)
elif proxy_mounts:
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
with httpx.Client(mounts=proxy_mounts) as client:
response = client.request(method=method, url=url, **kwargs)
else:

View File

@ -1,4 +1,4 @@
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
@ -23,6 +23,7 @@ __all__ = [
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"ImagePromptMessageContent",
"LLMMode",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",

View File

@ -1,5 +1,5 @@
from decimal import Decimal
from enum import Enum
from enum import StrEnum
from typing import Optional
from pydantic import BaseModel
@ -8,7 +8,7 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
class LLMMode(Enum):
class LLMMode(StrEnum):
"""
Enum class for large language model mode.
"""

View File

@ -3,8 +3,11 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
ModelType,
PriceConfig,
PriceInfo,
@ -18,6 +21,7 @@ from core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from core.plugin.manager.model import PluginModelManager
@ -144,3 +148,102 @@ class AIModel(BaseModel):
model=model,
credentials=credentials or {},
)
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials
:param model: model name
:param credentials: model credentials
:return: model schema
"""
return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""
schema = self.get_customizable_model_schema(model, credentials)
if not schema:
return None
# fill in the template
new_parameter_rules = []
for parameter_rule in schema.parameter_rules:
if parameter_rule.use_template:
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
if not parameter_rule.max and "max" in default_parameter_rule:
parameter_rule.max = default_parameter_rule["max"]
if not parameter_rule.min and "min" in default_parameter_rule:
parameter_rule.min = default_parameter_rule["min"]
if not parameter_rule.default and "default" in default_parameter_rule:
parameter_rule.default = default_parameter_rule["default"]
if not parameter_rule.precision and "precision" in default_parameter_rule:
parameter_rule.precision = default_parameter_rule["precision"]
if not parameter_rule.required and "required" in default_parameter_rule:
parameter_rule.required = default_parameter_rule["required"]
if not parameter_rule.help and "help" in default_parameter_rule:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule["help"]["en_US"],
)
if (
parameter_rule.help
and not parameter_rule.help.en_US
and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
):
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
if (
parameter_rule.help
and not parameter_rule.help.zh_Hans
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
):
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
"zh_Hans", default_parameter_rule["help"]["en_US"]
)
except ValueError:
pass
new_parameter_rules.append(parameter_rule)
schema.parameter_rules = new_parameter_rules
return schema
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema
:param model: model name
:param credentials: model credentials
:return: model schema
"""
return None
def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict:
"""
Get default parameter rule for given name
:param name: parameter name
:return: parameter rule
"""
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
if not default_parameter_rule:
raise Exception(f"Invalid model parameter rule name {name}")
return default_parameter_rule
def _get_num_tokens_by_gpt2(self, text: str) -> int:
"""
Get number of tokens for given prompt messages by gpt2
Some provider models do not provide an interface for obtaining the number of tokens.
Here, the gpt2 tokenizer is used to calculate the number of tokens.
This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
:param text: plain text of prompt. You need to convert the original message to plain text
:return: number of tokens
"""
return GPT2Tokenizer.get_num_tokens(text)

View File

@ -1,4 +1,5 @@
- openai
- deepseek
- anthropic
- azure_openai
- google
@ -32,7 +33,6 @@
- localai
- volcengine_maas
- openai_api_compatible
- deepseek
- hunyuan
- siliconflow
- perfxcloud

View File

@ -0,0 +1,41 @@
model: gemini-2.0-flash-001
label:
en_US: Gemini 2.0 Flash 001
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,41 @@
model: gemini-2.0-flash-lite-preview-02-05
label:
en_US: Gemini 2.0 Flash Lite Preview 0205
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-01-21
label:
en_US: Gemini 2.0 Flash Thinking Exp 0121
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-1219
label:
en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,37 @@
model: gemini-2.0-pro-exp-02-05
label:
en_US: Gemini 2.0 Pro Exp 0205
model_type: llm
features:
- agent-thought
- document
model_properties:
mode: chat
context_size: 2000000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
en_US: Top k
type: int
help:
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,41 @@
model: gemini-exp-1114
label:
en_US: Gemini exp 1114
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,41 @@
model: gemini-exp-1121
label:
en_US: Gemini exp 1121
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,41 @@
model: gemini-exp-1206
label:
en_US: Gemini exp 1206
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,66 @@
model: glm-4-air-0111
label:
en_US: glm-4-air-0111
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
default: 0.95
min: 0.0
max: 1.0
help:
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: top_p
use_template: top_p
default: 0.7
help:
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: do_sample
label:
zh_Hans: 采样策略
en_US: Sampling strategy
type: boolean
help:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 4095
- name: web_search
type: boolean
label:
zh_Hans: 联网搜索
en_US: Web Search
default: false
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.0005'
output: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -1,6 +1,6 @@
import json
import time
from typing import cast
from typing import Any, cast
import requests
@ -14,48 +14,47 @@ class FirecrawlApp:
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
raise ValueError("No API key provided")
def scrape_url(self, url, params=None) -> dict:
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
json_data = {"url": url}
def scrape_url(self, url, params=None) -> dict[str, Any]:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/scrape
headers = self._prepare_headers()
json_data = {
"url": url,
"formats": ["markdown"],
"onlyMainContent": True,
"timeout": 30000,
}
if params:
json_data.update(params)
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
response = self._post_request(f"{self.base_url}/v1/scrape", json_data, headers)
if response.status_code == 200:
response_data = response.json()
if response_data["success"] == True:
data = response_data["data"]
return {
"title": data.get("metadata").get("title"),
"description": data.get("metadata").get("description"),
"source_url": data.get("metadata").get("sourceURL"),
"markdown": data.get("markdown"),
}
else:
raise Exception(f"Failed to scrape URL. Error: {response_data['error']}")
elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
data = response_data["data"]
return self._extract_common_fields(data)
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "scrape URL")
return {} # Avoid additional exception after handling error
else:
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
def crawl_url(self, url, params=None) -> str:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
headers = self._prepare_headers()
json_data = {"url": url}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
if response.status_code == 200:
job_id = response.json().get("jobId")
# There's also another two fields in the response: "success" (bool) and "url" (str)
job_id = response.json().get("id")
return cast(str, job_id)
else:
self._handle_error(response, "start crawl job")
# FIXME: unreachable code for mypy
return "" # unreachable
def check_crawl_status(self, job_id) -> dict:
def check_crawl_status(self, job_id) -> dict[str, Any]:
headers = self._prepare_headers()
response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers)
response = self._get_request(f"{self.base_url}/v1/crawl/{job_id}", headers)
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get("status") == "completed":
@ -66,42 +65,48 @@ class FirecrawlApp:
url_data_list = []
for item in data:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = {
"title": item.get("metadata", {}).get("title"),
"description": item.get("metadata", {}).get("description"),
"source_url": item.get("metadata", {}).get("sourceURL"),
"markdown": item.get("markdown"),
}
url_data = self._extract_common_fields(item)
url_data_list.append(url_data)
if url_data_list:
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(url_data_list).encode("utf-8"))
return {
"status": "completed",
"total": crawl_status_response.get("total"),
"current": crawl_status_response.get("current"),
"data": url_data_list,
}
try:
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(url_data_list).encode("utf-8"))
except Exception as e:
raise Exception(f"Error saving crawl data: {e}")
return self._format_crawl_status_response("completed", crawl_status_response, url_data_list)
else:
return {
"status": crawl_status_response.get("status"),
"total": crawl_status_response.get("total"),
"current": crawl_status_response.get("current"),
"data": [],
}
return self._format_crawl_status_response(
crawl_status_response.get("status"), crawl_status_response, []
)
else:
self._handle_error(response, "check crawl status")
# FIXME: unreachable code for mypy
return {} # unreachable
def _prepare_headers(self):
def _format_crawl_status_response(
self, status: str, crawl_status_response: dict[str, Any], url_data_list: list[dict[str, Any]]
) -> dict[str, Any]:
return {
"status": status,
"total": crawl_status_response.get("total"),
"current": crawl_status_response.get("completed"),
"data": url_data_list,
}
def _extract_common_fields(self, item: dict[str, Any]) -> dict[str, Any]:
return {
"title": item.get("metadata", {}).get("title"),
"description": item.get("metadata", {}).get("description"),
"source_url": item.get("metadata", {}).get("sourceURL"),
"markdown": item.get("markdown"),
}
def _prepare_headers(self) -> dict[str, Any]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response:
for attempt in range(retries):
response = requests.post(url, headers=headers, json=data)
if response.status_code == 502:
@ -110,7 +115,7 @@ class FirecrawlApp:
return response
return response
def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> requests.Response:
for attempt in range(retries):
response = requests.get(url, headers=headers)
if response.status_code == 502:
@ -119,6 +124,6 @@ class FirecrawlApp:
return response
return response
def _handle_error(self, response, action):
def _handle_error(self, response, action) -> None:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")

View File

@ -13,9 +13,10 @@ class FirecrawlWebExtractor(BaseExtractor):
api_key: The API key for Firecrawl.
base_url: The base URL for the Firecrawl API. Defaults to 'https://api.firecrawl.dev'.
mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
"""
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id

View File

@ -223,14 +223,14 @@ class WorkflowTool(Tool):
if isinstance(value, list):
for item in value:
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
item["tool_file_id"] = item.get("related_id")
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value["tool_file_id"] = value.get("related_id")
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
@ -240,3 +240,11 @@ class WorkflowTool(Tool):
result[key] = value
return result, files
def _update_file_mapping(self, file_dict: dict) -> dict:
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
if transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_dict.get("related_id")
elif transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file_dict.get("related_id")
return file_dict

View File

@ -590,6 +590,8 @@ class Graph(BaseModel):
start_node_id=node_id,
routes_node_ids=routes_node_ids,
)
# Exclude conditional branch nodes
and all(edge.run_condition is None for edge in reverse_edge_mapping.get(node_id, []))
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []

View File

@ -195,7 +195,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if output_config.type == "object":
# check if output is object
if not isinstance(result.get(output_name), dict):
if isinstance(result.get(output_name), type(None)):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
@ -223,7 +223,7 @@ class CodeNode(BaseNode[CodeNodeData]):
elif output_config.type == "array[number]":
# check if array of number available
if not isinstance(result[output_name], list):
if isinstance(result[output_name], type(None)):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
@ -244,7 +244,7 @@ class CodeNode(BaseNode[CodeNodeData]):
elif output_config.type == "array[string]":
# check if array of string available
if not isinstance(result[output_name], list):
if isinstance(result[output_name], type(None)):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
@ -265,7 +265,7 @@ class CodeNode(BaseNode[CodeNodeData]):
elif output_config.type == "array[object]":
# check if array of object available
if not isinstance(result[output_name], list):
if isinstance(result[output_name], type(None)):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(

View File

@ -3,7 +3,7 @@ from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
@ -12,7 +12,7 @@ from core.workflow.nodes.base import BaseNodeData
class ModelConfig(BaseModel):
provider: str
name: str
mode: str
mode: LLMMode
completion_params: dict[str, Any] = {}

View File

@ -3,6 +3,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
@ -185,6 +186,8 @@ class LLMNode(BaseNode[LLMNodeData]):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
except LLMNodeError as e:
yield RunCompletedEvent(
@ -240,17 +243,7 @@ class LLMNode(BaseNode[LLMNodeData]):
user=self.user_id,
)
# handle invoke result
generator = self._handle_invoke_result(invoke_result=invoke_result)
usage = LLMUsage.empty_usage()
for event in generator:
yield event
if isinstance(event, ModelInvokeCompletedEvent):
usage = event.usage
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return self._handle_invoke_result(invoke_result=invoke_result)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
if isinstance(invoke_result, LLMResult):
@ -740,10 +733,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = 1
if "gpt-4" in model_instance.model:
used_quota = 20
used_quota = dify_config.get_model_credits(model_instance.model)
else:
used_quota = 1

View File

@ -20,11 +20,11 @@ if [[ "${MODE}" == "worker" ]]; then
CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
fi
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL} \
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL}
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
else
if [[ "${DEBUG}" == "true" ]]; then
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug

View File

@ -1,3 +1,4 @@
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from events.message_event import message_was_created
@ -40,10 +41,7 @@ def handle(sender, **kwargs):
if quota_unit == QuotaUnit.TOKENS:
used_quota = message.message_tokens + message.answer_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = 1
if "gpt-4" in model_config.model:
used_quota = 20
used_quota = dify_config.get_model_credits(model_config.model)
else:
used_quota = 1

View File

@ -27,12 +27,11 @@ def init_app(app: DifyApp):
# Always add StreamHandler to log to console
sh = logging.StreamHandler(sys.stdout)
sh.addFilter(RequestIdFilter())
log_formatter = logging.Formatter(fmt=dify_config.LOG_FORMAT)
sh.setFormatter(log_formatter)
log_handlers.append(sh)
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,

View File

@ -32,7 +32,11 @@ class AwsS3Storage(BaseStorage):
aws_access_key_id=dify_config.S3_ACCESS_KEY,
endpoint_url=dify_config.S3_ENDPOINT,
region_name=dify_config.S3_REGION,
config=Config(s3={"addressing_style": dify_config.S3_ADDRESS_STYLE}),
config=Config(
s3={"addressing_style": dify_config.S3_ADDRESS_STYLE},
request_checksum_calculation="when_required",
response_checksum_validation="when_required",
),
)
# create bucket
try:

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from typing import Optional
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from configs import dify_config
@ -18,6 +20,12 @@ class AzureBlobStorage(BaseStorage):
self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME
self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY
self.credential: Optional[ChainedTokenCredential] = None
if self.account_key == "managedidentity":
self.credential = DefaultAzureCredential()
else:
self.credential = None
def save(self, filename, data):
client = self._sync_client()
blob_container = client.get_container_client(container=self.bucket_name)
@ -57,6 +65,9 @@ class AzureBlobStorage(BaseStorage):
blob_container.delete_blob(filename)
def _sync_client(self):
if self.account_key == "managedidentity":
return BlobServiceClient(account_url=self.account_url, credential=self.credential) # type: ignore
cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key)
cache_result = redis_client.get(cache_key)
if cache_result is not None:

View File

@ -1147,8 +1147,10 @@ class Message(Base):
"id": self.id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"model_id": self.model_id,
"inputs": self.inputs,
"query": self.query,
"total_price": self.total_price,
"message": self.message,
"answer": self.answer,
"status": self.status,
@ -1169,7 +1171,9 @@ class Message(Base):
id=data["id"],
app_id=data["app_id"],
conversation_id=data["conversation_id"],
model_id=data["model_id"],
inputs=data["inputs"],
total_price=data["total_price"],
query=data["query"],
message=data["message"],
answer=data["answer"],

1567
api/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -18,7 +18,7 @@ package-mode = false
authlib = "1.3.1"
azure-identity = "1.16.1"
beautifulsoup4 = "4.12.2"
boto3 = "1.35.74"
boto3 = "1.36.12"
bs4 = "~0.0.1"
cachetools = "~5.3.0"
celery = "~5.4.0"
@ -47,7 +47,7 @@ mailchimp-transactional = "~1.0.50"
markdown = "~3.5.1"
numpy = "~1.26.4"
oci = "~2.135.1"
openai = "~1.52.0"
openai = "~1.61.0"
openpyxl = "~3.1.5"
opik = "~1.3.4"
pandas = { version = "~2.2.2", extras = ["performance", "excel"] }
@ -73,7 +73,6 @@ starlette = "0.41.0"
tiktoken = "~0.8.0"
tokenizers = "~0.15.0"
transformers = "~4.35.0"
types-pytz = "~2024.2.0.20241003"
unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
validators = "0.21.0"
yarl = "~1.18.3"
@ -150,6 +149,21 @@ pytest = "~8.3.2"
pytest-benchmark = "~4.0.0"
pytest-env = "~1.1.3"
pytest-mock = "~3.14.0"
types-beautifulsoup4 = "~4.12.0.20241020"
types-flask-cors = "~5.0.0.20240902"
types-flask-migrate = "~4.1.0.20250112"
types-html5lib = "~1.1.11.20241018"
types-openpyxl = "~3.1.5.20241225"
types-protobuf = "~5.29.1.20241207"
types-psutil = "~6.1.0.20241221"
types-psycopg2 = "~2.9.21.20250121"
types-python-dateutil = "~2.9.0.20241206"
types-pytz = "~2024.2.0.20241221"
types-pyyaml = "~6.0.12.20241230"
types-regex = "~2024.11.6.20241221"
types-requests = "~2.32.0.20241016"
types-six = "~1.17.0.20241205"
types-tqdm = "~4.67.0.20241221"
############################################################
# [ Lint ] dependency group

View File

@ -79,6 +79,7 @@ class AccountService:
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
)
LOGIN_MAX_ERROR_LIMITS = 5
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
@ -525,6 +526,32 @@ class AccountService:
key = f"login_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
def add_forgot_password_error_rate_limit(email: str) -> None:
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
count = 0
count = int(count) + 1
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
@staticmethod
def is_forgot_password_error_rate_limit(email: str) -> bool:
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
return False
count = int(count)
if count > AccountService.FORGOT_PASSWORD_MAX_ERROR_LIMITS:
return True
return False
@staticmethod
def reset_forgot_password_error_rate_limit(email: str):
key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
def is_email_send_ip_limit(ip_address: str):
minute_key = f"email_send_ip_limit_minute:{ip_address}"

View File

@ -21,10 +21,12 @@ class FirecrawlAuth(ApiKeyAuthBase):
headers = self._prepare_headers()
options = {
"url": "https://example.com",
"crawlerOptions": {"excludes": [], "includes": [], "limit": 1},
"pageOptions": {"onlyMainContent": True},
"includePaths": [],
"excludePaths": [],
"limit": 1,
"scrapeOptions": {"onlyMainContent": True},
}
response = self._post_request(f"{self.base_url}/v0/crawl", options, headers)
response = self._post_request(f"{self.base_url}/v1/crawl", options, headers)
if response.status_code == 200:
return True
else:

View File

@ -44,6 +44,7 @@ from models.source import DataSourceOauthBinding
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
MetaDataConfig,
RerankingModel,
RetrievalModel,
SegmentUpdateArgs,
@ -915,6 +916,9 @@ class DocumentService:
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
if knowledge_config.metadata:
document.doc_type = knowledge_config.metadata.doc_type
document.metadata = knowledge_config.metadata.doc_metadata
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
@ -931,6 +935,7 @@ class DocumentService:
account,
file_name,
batch,
knowledge_config.metadata,
)
db.session.add(document)
db.session.flush()
@ -986,6 +991,7 @@ class DocumentService:
account,
page.page_name,
batch,
knowledge_config.metadata,
)
db.session.add(document)
db.session.flush()
@ -1026,6 +1032,7 @@ class DocumentService:
account,
document_name,
batch,
knowledge_config.metadata,
)
db.session.add(document)
db.session.flush()
@ -1063,6 +1070,7 @@ class DocumentService:
account: Account,
name: str,
batch: str,
metadata: Optional[MetaDataConfig] = None,
):
document = Document(
tenant_id=dataset.tenant_id,
@ -1078,6 +1086,9 @@ class DocumentService:
doc_form=document_form,
doc_language=document_language,
)
if metadata is not None:
document.doc_metadata = metadata.doc_metadata
document.doc_type = metadata.doc_type
return document
@staticmethod
@ -1190,6 +1201,10 @@ class DocumentService:
# update document name
if document_data.name:
document.name = document_data.name
# update doc_type and doc_metadata if provided
if document_data.metadata is not None:
document.doc_metadata = document_data.metadata.doc_type
document.doc_type = document_data.metadata.doc_type
# update document to be waiting
document.indexing_status = "waiting"
document.completed_at = None

View File

@ -93,6 +93,11 @@ class RetrievalModel(BaseModel):
score_threshold: Optional[float] = None
class MetaDataConfig(BaseModel):
doc_type: str
doc_metadata: dict
class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
duplicate: bool = True
@ -105,6 +110,7 @@ class KnowledgeConfig(BaseModel):
embedding_model: Optional[str] = None
embedding_model_provider: Optional[str] = None
name: Optional[str] = None
metadata: Optional[MetaDataConfig] = None
class SegmentUpdateArgs(BaseModel):

View File

@ -38,30 +38,22 @@ class WebsiteService:
only_main_content = options.get("only_main_content", False)
if not crawl_sub_pages:
params = {
"crawlerOptions": {
"includes": [],
"excludes": [],
"generateImgAltText": True,
"limit": 1,
"returnOnlyUrls": False,
"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
}
"includePaths": [],
"excludePaths": [],
"limit": 1,
"scrapeOptions": {"onlyMainContent": only_main_content},
}
else:
includes = options.get("includes").split(",") if options.get("includes") else []
excludes = options.get("excludes").split(",") if options.get("excludes") else []
params = {
"crawlerOptions": {
"includes": includes,
"excludes": excludes,
"generateImgAltText": True,
"limit": options.get("limit", 1),
"returnOnlyUrls": False,
"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
}
"includePaths": includes,
"excludePaths": excludes,
"limit": options.get("limit", 1),
"scrapeOptions": {"onlyMainContent": only_main_content},
}
if options.get("max_depth"):
params["crawlerOptions"]["maxDepth"] = options.get("max_depth")
params["maxDepth"] = options.get("max_depth")
job_id = firecrawl_app.crawl_url(url, params)
website_crawl_time_cache_key = f"website_crawl_{job_id}"
time = str(datetime.datetime.now().timestamp())
@ -228,7 +220,7 @@ class WebsiteService:
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}}
params = {"onlyMainContent": only_main_content}
result = firecrawl_app.scrape_url(url, params)
return result
else:

View File

@ -10,19 +10,17 @@ def test_firecrawl_web_extractor_crawl_mode(mocker):
base_url = "https://api.firecrawl.dev"
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url)
params = {
"crawlerOptions": {
"includes": [],
"excludes": [],
"generateImgAltText": True,
"maxDepth": 1,
"limit": 1,
"returnOnlyUrls": False,
}
"includePaths": [],
"excludePaths": [],
"maxDepth": 1,
"limit": 1,
}
mocked_firecrawl = {
"jobId": "test",
"id": "test",
}
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
job_id = firecrawl_app.crawl_url(url, params)
print(job_id)
print(f"job_id: {job_id}")
assert job_id is not None
assert isinstance(job_id, str)