mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 07:58:02 +08:00
Merge branch 'feat/model-plugins-implementing' into deploy/dev
This commit is contained in:
@ -4,6 +4,7 @@ import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
@ -32,7 +33,7 @@ from extensions.ext_redis import redis_client
|
||||
from factories import variable_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig, IconType
|
||||
from models.model import AppModelConfig, AppModelConfigDict, IconType
|
||||
from models.workflow import Workflow
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
@ -523,7 +524,7 @@ class AppDslService:
|
||||
if not app.app_model_config:
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
).from_model_config_dict(model_config)
|
||||
).from_model_config_dict(cast(AppModelConfigDict, model_config))
|
||||
app_model_config.id = str(uuid4())
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
|
||||
|
||||
class AppModelConfigService:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
|
||||
if app_mode == AppMode.CHAT:
|
||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||
elif app_mode == AppMode.AGENT_CHAT:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import TypedDict, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
@ -187,7 +187,7 @@ class AppService:
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool))
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
@ -388,7 +388,7 @@ class AppService:
|
||||
agent_config = app_model_config.agent_mode_dict
|
||||
|
||||
# get all tools
|
||||
tools = agent_config.get("tools", [])
|
||||
tools = cast(list[dict[str, Any]], agent_config.get("tools", []))
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ import io
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from werkzeug.datastructures import FileStorage
|
||||
@ -106,7 +107,7 @@ class AudioService:
|
||||
if not text_to_speech_dict.get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = text_to_speech_dict.get("voice")
|
||||
voice = cast(str | None, text_to_speech_dict.get("voice"))
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
|
||||
@ -63,7 +63,12 @@ class RagPipelineTransformService:
|
||||
):
|
||||
node = self._deal_file_extensions(node)
|
||||
if node.get("data", {}).get("type") == "knowledge-index":
|
||||
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
|
||||
if dataset.tenant_id != current_user.current_tenant_id:
|
||||
raise ValueError("Unauthorized")
|
||||
node = self._deal_knowledge_index(
|
||||
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
|
||||
)
|
||||
new_nodes.append(node)
|
||||
if new_nodes:
|
||||
graph["nodes"] = new_nodes
|
||||
@ -155,14 +160,13 @@ class RagPipelineTransformService:
|
||||
|
||||
def _deal_knowledge_index(
|
||||
self,
|
||||
knowledge_configuration: KnowledgeConfiguration,
|
||||
dataset: Dataset,
|
||||
doc_form: str,
|
||||
indexing_technique: str | None,
|
||||
retrieval_model: RetrievalSetting | None,
|
||||
node: dict,
|
||||
):
|
||||
knowledge_configuration_dict = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
|
||||
|
||||
if indexing_technique == "high_quality":
|
||||
knowledge_configuration.embedding_model = dataset.embedding_model
|
||||
|
||||
@ -12,6 +12,7 @@ from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@ -142,7 +143,7 @@ class MessagesCleanService:
|
||||
if batch_size <= 0:
|
||||
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
|
||||
|
||||
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
|
||||
end_before = naive_utc_now() - datetime.timedelta(days=days)
|
||||
|
||||
logger.info(
|
||||
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
|
||||
|
||||
Reference in New Issue
Block a user