mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts: # api/core/app/apps/advanced_chat/app_generator.py # api/core/app/apps/advanced_chat/generate_task_pipeline.py # api/core/app/apps/workflow/app_runner.py # api/core/app/apps/workflow/generate_task_pipeline.py # api/core/app/task_pipeline/workflow_cycle_state_manager.py # api/core/workflow/entities/variable_pool.py # api/core/workflow/nodes/code/code_node.py # api/core/workflow/nodes/llm/llm_node.py # api/core/workflow/nodes/start/start_node.py # api/core/workflow/nodes/variable_assigner/__init__.py # api/tests/integration_tests/workflow/nodes/test_llm.py # api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py # api/tests/unit_tests/core/workflow/nodes/test_answer.py # api/tests/unit_tests/core/workflow/nodes/test_if_else.py # api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
This commit is contained in:
@ -82,6 +82,7 @@ class AppDslService:
|
||||
# get app basic info
|
||||
name = args.get("name") if args.get("name") else app_data.get('name')
|
||||
description = args.get("description") if args.get("description") else app_data.get('description', '')
|
||||
icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get('icon_type')
|
||||
icon = args.get("icon") if args.get("icon") else app_data.get('icon')
|
||||
icon_background = args.get("icon_background") if args.get("icon_background") \
|
||||
else app_data.get('icon_background')
|
||||
@ -96,6 +97,7 @@ class AppDslService:
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
)
|
||||
@ -107,6 +109,7 @@ class AppDslService:
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
)
|
||||
@ -165,8 +168,8 @@ class AppDslService:
|
||||
"app": {
|
||||
"name": app_model.name,
|
||||
"mode": app_model.mode,
|
||||
"icon": app_model.icon,
|
||||
"icon_background": app_model.icon_background,
|
||||
"icon": '🤖' if app_model.icon_type == 'image' else app_model.icon,
|
||||
"icon_background": '#FFEAD5' if app_model.icon_type == 'image' else app_model.icon_background,
|
||||
"description": app_model.description
|
||||
}
|
||||
}
|
||||
@ -207,6 +210,7 @@ class AppDslService:
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
"""
|
||||
@ -218,6 +222,7 @@ class AppDslService:
|
||||
:param account: Account instance
|
||||
:param name: app name
|
||||
:param description: app description
|
||||
:param icon_type: app icon type, "emoji" or "image"
|
||||
:param icon: app icon
|
||||
:param icon_background: app icon background
|
||||
"""
|
||||
@ -231,6 +236,7 @@ class AppDslService:
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
)
|
||||
@ -307,6 +313,7 @@ class AppDslService:
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
"""
|
||||
@ -331,6 +338,7 @@ class AppDslService:
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
)
|
||||
@ -358,6 +366,7 @@ class AppDslService:
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
"""
|
||||
@ -368,6 +377,7 @@ class AppDslService:
|
||||
:param account: Account instance
|
||||
:param name: app name
|
||||
:param description: app description
|
||||
:param icon_type: app icon type, "emoji" or "image"
|
||||
:param icon: app icon
|
||||
:param icon_background: app icon background
|
||||
"""
|
||||
@ -376,6 +386,7 @@ class AppDslService:
|
||||
mode=app_mode.value,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background,
|
||||
enable_site=True,
|
||||
|
||||
@ -111,6 +111,12 @@ class AppService:
|
||||
'completion_params': {}
|
||||
}
|
||||
else:
|
||||
provider, model = model_manager.get_default_provider_model_name(
|
||||
tenant_id=account.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
default_model_config['model']['provider'] = provider
|
||||
default_model_config['model']['name'] = model
|
||||
default_model_dict = default_model_config['model']
|
||||
|
||||
default_model_config['model'] = json.dumps(default_model_dict)
|
||||
@ -119,6 +125,7 @@ class AppService:
|
||||
app.name = args['name']
|
||||
app.description = args.get('description', '')
|
||||
app.mode = args['mode']
|
||||
app.icon_type = args.get('icon_type', 'emoji')
|
||||
app.icon = args['icon']
|
||||
app.icon_background = args['icon_background']
|
||||
app.tenant_id = tenant_id
|
||||
@ -189,13 +196,14 @@ class AppService:
|
||||
"""
|
||||
Modified App class
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.__dict__.update(app.__dict__)
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
return model_config
|
||||
|
||||
|
||||
app = ModifiedApp(app)
|
||||
|
||||
return app
|
||||
@ -210,6 +218,7 @@ class AppService:
|
||||
app.name = args.get('name')
|
||||
app.description = args.get('description', '')
|
||||
app.max_active_requests = args.get('max_active_requests')
|
||||
app.icon_type = args.get('icon_type', 'emoji')
|
||||
app.icon = args.get('icon')
|
||||
app.icon_background = args.get('icon_background')
|
||||
app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import asc, desc, or_
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
@ -18,7 +19,8 @@ class ConversationService:
|
||||
last_id: Optional[str], limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
include_ids: Optional[list] = None,
|
||||
exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
|
||||
exclude_ids: Optional[list] = None,
|
||||
sort_by: str = '-updated_at') -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
@ -37,28 +39,28 @@ class ConversationService:
|
||||
if exclude_ids is not None:
|
||||
base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
|
||||
|
||||
if last_id:
|
||||
last_conversation = base_query.filter(
|
||||
Conversation.id == last_id,
|
||||
).first()
|
||||
# define sort fields and directions
|
||||
sort_field, sort_direction = cls._get_sort_params(sort_by)
|
||||
|
||||
if last_id:
|
||||
last_conversation = base_query.filter(Conversation.id == last_id).first()
|
||||
if not last_conversation:
|
||||
raise LastConversationNotExistsError()
|
||||
|
||||
conversations = base_query.filter(
|
||||
Conversation.created_at < last_conversation.created_at,
|
||||
Conversation.id != last_conversation.id
|
||||
).order_by(Conversation.created_at.desc()).limit(limit).all()
|
||||
else:
|
||||
conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all()
|
||||
# build filters based on sorting
|
||||
filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation)
|
||||
base_query = base_query.filter(filter_condition)
|
||||
|
||||
base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field)))
|
||||
|
||||
conversations = base_query.limit(limit).all()
|
||||
|
||||
has_more = False
|
||||
if len(conversations) == limit:
|
||||
current_page_first_conversation = conversations[-1]
|
||||
rest_count = base_query.filter(
|
||||
Conversation.created_at < current_page_first_conversation.created_at,
|
||||
Conversation.id != current_page_first_conversation.id
|
||||
).count()
|
||||
current_page_last_conversation = conversations[-1]
|
||||
rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction,
|
||||
current_page_last_conversation, is_next_page=True)
|
||||
rest_count = base_query.filter(rest_filter_condition).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
@ -69,6 +71,21 @@ class ConversationService:
|
||||
has_more=has_more
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]:
|
||||
if sort_by.startswith('-'):
|
||||
return sort_by[1:], desc
|
||||
return sort_by, asc
|
||||
|
||||
@classmethod
|
||||
def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation,
|
||||
is_next_page: bool = False):
|
||||
field_value = getattr(reference_conversation, sort_field)
|
||||
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
|
||||
return getattr(Conversation, sort_field) < field_value
|
||||
else:
|
||||
return getattr(Conversation, sort_field) > field_value
|
||||
|
||||
@classmethod
|
||||
def rename(cls, app_model: App, conversation_id: str,
|
||||
user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool):
|
||||
@ -78,6 +95,7 @@ class ConversationService:
|
||||
return cls.auto_generate_name(app_model, conversation)
|
||||
else:
|
||||
conversation.name = name
|
||||
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return conversation
|
||||
@ -87,9 +105,9 @@ class ConversationService:
|
||||
# get conversation first message
|
||||
message = db.session.query(Message) \
|
||||
.filter(
|
||||
Message.app_id == app_model.id,
|
||||
Message.conversation_id == conversation.id
|
||||
).order_by(Message.created_at.asc()).first()
|
||||
Message.app_id == app_model.id,
|
||||
Message.conversation_id == conversation.id
|
||||
).order_by(Message.created_at.asc()).first()
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
|
||||
@ -1429,7 +1429,10 @@ class SegmentService:
|
||||
segment_data_list.append(segment_document)
|
||||
|
||||
pre_segment_data_list.append(segment_document)
|
||||
keywords_list.append(segment_item['keywords'])
|
||||
if 'keywords' in segment_item:
|
||||
keywords_list.append(segment_item['keywords'])
|
||||
else:
|
||||
keywords_list.append(None)
|
||||
|
||||
try:
|
||||
# save vector index
|
||||
@ -1482,7 +1485,7 @@ class SegmentService:
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
# update segment index task
|
||||
if args['keywords']:
|
||||
if 'keywords' in args:
|
||||
keyword = Keyword(dataset)
|
||||
keyword.delete_by_ids([segment.index_node_id])
|
||||
document = RAGDocument(
|
||||
|
||||
@ -30,6 +30,7 @@ class ModelProviderService:
|
||||
"""
|
||||
Model Provider Service
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
@ -387,18 +388,21 @@ class ModelProviderService:
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type_enum
|
||||
)
|
||||
|
||||
return DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types
|
||||
)
|
||||
) if result else None
|
||||
try:
|
||||
return DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types
|
||||
)
|
||||
) if result else None
|
||||
except Exception as e:
|
||||
logger.info(f"get_default_model_of_model_type error: {e}")
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
"""
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
@ -43,14 +45,14 @@ class BuiltinToolManageService:
|
||||
result = []
|
||||
for tool in tools:
|
||||
result.append(ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name
|
||||
@ -78,7 +80,7 @@ class BuiltinToolManageService:
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
|
||||
try:
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
@ -119,8 +121,8 @@ class BuiltinToolManageService:
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
@ -135,7 +137,7 @@ class BuiltinToolManageService:
|
||||
|
||||
if provider is None:
|
||||
return {}
|
||||
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
@ -156,7 +158,7 @@ class BuiltinToolManageService:
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
@ -165,8 +167,8 @@ class BuiltinToolManageService:
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(
|
||||
provider: str
|
||||
@ -179,7 +181,7 @@ class BuiltinToolManageService:
|
||||
icon_bytes = f.read()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(
|
||||
user_id: str, tenant_id: str
|
||||
@ -202,6 +204,15 @@ class BuiltinToolManageService:
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name
|
||||
):
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
@ -226,4 +237,3 @@ class BuiltinToolManageService:
|
||||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@ -32,6 +32,7 @@ class WorkflowToolManageService:
|
||||
:param description: the description
|
||||
:param parameters: the parameters
|
||||
:param privacy_policy: the privacy policy
|
||||
:param labels: labels
|
||||
:return: the created tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
@ -92,7 +93,14 @@ class WorkflowToolManageService:
|
||||
Update a workflow tool.
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:param tool: the tool
|
||||
:param workflow_tool_id: workflow tool id
|
||||
:param name: name
|
||||
:param label: label
|
||||
:param icon: icon
|
||||
:param description: description
|
||||
:param parameters: parameters
|
||||
:param privacy_policy: privacy policy
|
||||
:param labels: labels
|
||||
:return: the updated tool
|
||||
"""
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
@ -13,7 +13,8 @@ class WebConversationService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int, invoke_from: InvokeFrom,
|
||||
pinned: Optional[bool] = None) -> InfiniteScrollPagination:
|
||||
pinned: Optional[bool] = None,
|
||||
sort_by='-updated_at') -> InfiniteScrollPagination:
|
||||
include_ids = None
|
||||
exclude_ids = None
|
||||
if pinned is not None:
|
||||
@ -36,6 +37,7 @@ class WebConversationService:
|
||||
invoke_from=invoke_from,
|
||||
include_ids=include_ids,
|
||||
exclude_ids=exclude_ids,
|
||||
sort_by=sort_by
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -32,11 +32,9 @@ class WorkflowConverter:
|
||||
App Convert to Workflow Mode
|
||||
"""
|
||||
|
||||
def convert_to_workflow(self, app_model: App,
|
||||
account: Account,
|
||||
name: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
def convert_to_workflow(
|
||||
self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str
|
||||
):
|
||||
"""
|
||||
Convert app to workflow
|
||||
|
||||
@ -50,22 +48,24 @@ class WorkflowConverter:
|
||||
:param account: Account
|
||||
:param name: new app name
|
||||
:param icon: new app icon
|
||||
:param icon_type: new app icon type
|
||||
:param icon_background: new app icon background
|
||||
:return: new App instance
|
||||
"""
|
||||
# convert app model config
|
||||
if not app_model.app_model_config:
|
||||
raise ValueError("App model config is required")
|
||||
|
||||
workflow = self.convert_app_model_config_to_workflow(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model.app_model_config,
|
||||
account_id=account.id
|
||||
app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id
|
||||
)
|
||||
|
||||
# create new app
|
||||
new_app = App()
|
||||
new_app.tenant_id = app_model.tenant_id
|
||||
new_app.name = name if name else app_model.name + '(workflow)'
|
||||
new_app.mode = AppMode.ADVANCED_CHAT.value \
|
||||
if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
|
||||
new_app.name = name if name else app_model.name + "(workflow)"
|
||||
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
|
||||
new_app.icon_type = icon_type if icon_type else app_model.icon_type
|
||||
new_app.icon = icon if icon else app_model.icon
|
||||
new_app.icon_background = icon_background if icon_background else app_model.icon_background
|
||||
new_app.enable_site = app_model.enable_site
|
||||
@ -85,30 +85,21 @@ class WorkflowConverter:
|
||||
|
||||
return new_app
|
||||
|
||||
def convert_app_model_config_to_workflow(self, app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
account_id: str) -> Workflow:
|
||||
def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str):
|
||||
"""
|
||||
Convert app model config to workflow mode
|
||||
:param app_model: App instance
|
||||
:param app_model_config: AppModelConfig instance
|
||||
:param account_id: Account ID
|
||||
:return:
|
||||
"""
|
||||
# get new app mode
|
||||
new_app_mode = self._get_new_app_mode(app_model)
|
||||
|
||||
# convert app model config
|
||||
app_config = self._convert_to_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
|
||||
# init workflow graph
|
||||
graph = {
|
||||
"nodes": [],
|
||||
"edges": []
|
||||
}
|
||||
graph = {"nodes": [], "edges": []}
|
||||
|
||||
# Convert list:
|
||||
# - variables -> start
|
||||
@ -120,11 +111,9 @@ class WorkflowConverter:
|
||||
# - show_retrieve_source -> knowledge-retrieval
|
||||
|
||||
# convert to start node
|
||||
start_node = self._convert_to_start_node(
|
||||
variables=app_config.variables
|
||||
)
|
||||
start_node = self._convert_to_start_node(variables=app_config.variables)
|
||||
|
||||
graph['nodes'].append(start_node)
|
||||
graph["nodes"].append(start_node)
|
||||
|
||||
# convert to http request node
|
||||
external_data_variable_node_mapping = {}
|
||||
@ -132,7 +121,7 @@ class WorkflowConverter:
|
||||
http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
|
||||
app_model=app_model,
|
||||
variables=app_config.variables,
|
||||
external_data_variables=app_config.external_data_variables
|
||||
external_data_variables=app_config.external_data_variables,
|
||||
)
|
||||
|
||||
for http_request_node in http_request_nodes:
|
||||
@ -141,9 +130,7 @@ class WorkflowConverter:
|
||||
# convert to knowledge retrieval node
|
||||
if app_config.dataset:
|
||||
knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode,
|
||||
dataset_config=app_config.dataset,
|
||||
model_config=app_config.model
|
||||
new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model
|
||||
)
|
||||
|
||||
if knowledge_retrieval_node:
|
||||
@ -157,7 +144,7 @@ class WorkflowConverter:
|
||||
model_config=app_config.model,
|
||||
prompt_template=app_config.prompt_template,
|
||||
file_upload=app_config.additional_features.file_upload,
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping,
|
||||
)
|
||||
|
||||
graph = self._append_node(graph, llm_node)
|
||||
@ -196,11 +183,12 @@ class WorkflowConverter:
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(new_app_mode).value,
|
||||
version='draft',
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account_id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
db.session.add(workflow)
|
||||
@ -208,24 +196,18 @@ class WorkflowConverter:
|
||||
|
||||
return workflow
|
||||
|
||||
def _convert_to_app_config(self, app_model: App,
|
||||
app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
|
||||
def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
app_model.mode = AppMode.AGENT_CHAT.value
|
||||
app_config = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
app_model=app_model, app_model_config=app_model_config
|
||||
)
|
||||
elif app_mode == AppMode.CHAT:
|
||||
app_config = ChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
elif app_mode == AppMode.COMPLETION:
|
||||
app_config = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config
|
||||
app_model=app_model, app_model_config=app_model_config
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
@ -244,14 +226,13 @@ class WorkflowConverter:
|
||||
"data": {
|
||||
"title": "START",
|
||||
"type": NodeType.START.value,
|
||||
"variables": [jsonable_encoder(v) for v in variables]
|
||||
}
|
||||
"variables": [jsonable_encoder(v) for v in variables],
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_http_request_node(self, app_model: App,
|
||||
variables: list[VariableEntity],
|
||||
external_data_variables: list[ExternalDataVariableEntity]) \
|
||||
-> tuple[list[dict], dict[str, str]]:
|
||||
def _convert_to_http_request_node(
|
||||
self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity]
|
||||
) -> tuple[list[dict], dict[str, str]]:
|
||||
"""
|
||||
Convert API Based Extension to HTTP Request Node
|
||||
:param app_model: App instance
|
||||
@ -273,40 +254,33 @@ class WorkflowConverter:
|
||||
|
||||
# get params from config
|
||||
api_based_extension_id = tool_config.get("api_based_extension_id")
|
||||
if not api_based_extension_id:
|
||||
continue
|
||||
|
||||
# get api_based_extension
|
||||
api_based_extension = self._get_api_based_extension(
|
||||
tenant_id=tenant_id,
|
||||
api_based_extension_id=api_based_extension_id
|
||||
tenant_id=tenant_id, api_based_extension_id=api_based_extension_id
|
||||
)
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError("[External data tool] API query failed, variable: {}, "
|
||||
"error: api_based_extension_id is invalid"
|
||||
.format(tool_variable))
|
||||
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=api_based_extension.api_key
|
||||
)
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key)
|
||||
|
||||
inputs = {}
|
||||
for v in variables:
|
||||
inputs[v.variable] = '{{#start.' + v.variable + '#}}'
|
||||
inputs[v.variable] = "{{#start." + v.variable + "#}}"
|
||||
|
||||
request_body = {
|
||||
'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
|
||||
'params': {
|
||||
'app_id': app_model.id,
|
||||
'tool_variable': tool_variable,
|
||||
'inputs': inputs,
|
||||
'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else ''
|
||||
}
|
||||
"point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
|
||||
"params": {
|
||||
"app_id": app_model.id,
|
||||
"tool_variable": tool_variable,
|
||||
"inputs": inputs,
|
||||
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
|
||||
},
|
||||
}
|
||||
|
||||
request_body_json = json.dumps(request_body)
|
||||
request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}')
|
||||
request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}")
|
||||
|
||||
http_request_node = {
|
||||
"id": f"http_request_{index}",
|
||||
@ -316,20 +290,11 @@ class WorkflowConverter:
|
||||
"type": NodeType.HTTP_REQUEST.value,
|
||||
"method": "post",
|
||||
"url": api_based_extension.api_endpoint,
|
||||
"authorization": {
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "bearer",
|
||||
"api_key": api_key
|
||||
}
|
||||
},
|
||||
"authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}},
|
||||
"headers": "",
|
||||
"params": "",
|
||||
"body": {
|
||||
"type": "json",
|
||||
"data": request_body_json
|
||||
}
|
||||
}
|
||||
"body": {"type": "json", "data": request_body_json},
|
||||
},
|
||||
}
|
||||
|
||||
nodes.append(http_request_node)
|
||||
@ -341,32 +306,24 @@ class WorkflowConverter:
|
||||
"data": {
|
||||
"title": f"Parse {api_based_extension.name} Response",
|
||||
"type": NodeType.CODE.value,
|
||||
"variables": [{
|
||||
"variable": "response_json",
|
||||
"value_selector": [http_request_node['id'], "body"]
|
||||
}],
|
||||
"variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}],
|
||||
"code_language": "python3",
|
||||
"code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads("
|
||||
"response_json)\n return {\n \"result\": response_body[\"result\"]\n }",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
'response_json)\n return {\n "result": response_body["result"]\n }',
|
||||
"outputs": {"result": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
|
||||
nodes.append(code_node)
|
||||
|
||||
external_data_variable_node_mapping[external_data_variable.variable] = code_node['id']
|
||||
external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"]
|
||||
index += 1
|
||||
|
||||
return nodes, external_data_variable_node_mapping
|
||||
|
||||
def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode,
|
||||
dataset_config: DatasetEntity,
|
||||
model_config: ModelConfigEntity) \
|
||||
-> Optional[dict]:
|
||||
def _convert_to_knowledge_retrieval_node(
|
||||
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Convert datasets to Knowledge Retrieval Node
|
||||
:param new_app_mode: new app mode
|
||||
@ -400,7 +357,7 @@ class WorkflowConverter:
|
||||
"completion_params": {
|
||||
**model_config.parameters,
|
||||
"stop": model_config.stop,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
@ -408,20 +365,23 @@ class WorkflowConverter:
|
||||
"multiple_retrieval_config": {
|
||||
"top_k": retrieve_config.top_k,
|
||||
"score_threshold": retrieve_config.score_threshold,
|
||||
"reranking_model": retrieve_config.reranking_model
|
||||
"reranking_model": retrieve_config.reranking_model,
|
||||
}
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
|
||||
else None,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_llm_node(self, original_app_mode: AppMode,
|
||||
new_app_mode: AppMode,
|
||||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileExtraConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] = None) -> dict:
|
||||
def _convert_to_llm_node(
|
||||
self,
|
||||
original_app_mode: AppMode,
|
||||
new_app_mode: AppMode,
|
||||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileExtraConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert to LLM Node
|
||||
:param original_app_mode: original app mode
|
||||
@ -433,17 +393,18 @@ class WorkflowConverter:
|
||||
:param external_data_variable_node_mapping: external data variable node mapping
|
||||
"""
|
||||
# fetch start and knowledge retrieval node
|
||||
start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes']))
|
||||
knowledge_retrieval_node = next(filter(
|
||||
lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value,
|
||||
graph['nodes']
|
||||
), None)
|
||||
start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"]))
|
||||
knowledge_retrieval_node = next(
|
||||
filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None
|
||||
)
|
||||
|
||||
role_prefix = None
|
||||
|
||||
# Chat Model
|
||||
if model_config.mode == LLMMode.CHAT.value:
|
||||
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
if not prompt_template.simple_prompt_template:
|
||||
raise ValueError("Simple prompt template is required")
|
||||
# get prompt template
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_template_config = prompt_transform.get_prompt_template(
|
||||
@ -452,45 +413,35 @@ class WorkflowConverter:
|
||||
model=model_config.model,
|
||||
pre_prompt=prompt_template.simple_prompt_template,
|
||||
has_context=knowledge_retrieval_node is not None,
|
||||
query_in_prompt=False
|
||||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config['prompt_template'].template
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
if not template:
|
||||
prompts = []
|
||||
else:
|
||||
template = self._replace_template_variables(
|
||||
template,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
template, start_node["data"]["variables"], external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
prompts = [
|
||||
{
|
||||
"role": 'user',
|
||||
"text": template
|
||||
}
|
||||
]
|
||||
prompts = [{"role": "user", "text": template}]
|
||||
else:
|
||||
advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template
|
||||
|
||||
prompts = []
|
||||
for m in advanced_chat_prompt_template.messages:
|
||||
if advanced_chat_prompt_template:
|
||||
if advanced_chat_prompt_template:
|
||||
for m in advanced_chat_prompt_template.messages:
|
||||
text = m.text
|
||||
text = self._replace_template_variables(
|
||||
text,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
text, start_node["data"]["variables"], external_data_variable_node_mapping
|
||||
)
|
||||
|
||||
prompts.append({
|
||||
"role": m.role.value,
|
||||
"text": text
|
||||
})
|
||||
prompts.append({"role": m.role.value, "text": text})
|
||||
# Completion Model
|
||||
else:
|
||||
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
if not prompt_template.simple_prompt_template:
|
||||
raise ValueError("Simple prompt template is required")
|
||||
# get prompt template
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_template_config = prompt_transform.get_prompt_template(
|
||||
@ -499,57 +450,50 @@ class WorkflowConverter:
|
||||
model=model_config.model,
|
||||
pre_prompt=prompt_template.simple_prompt_template,
|
||||
has_context=knowledge_retrieval_node is not None,
|
||||
query_in_prompt=False
|
||||
query_in_prompt=False,
|
||||
)
|
||||
|
||||
template = prompt_template_config['prompt_template'].template
|
||||
template = prompt_template_config["prompt_template"].template
|
||||
template = self._replace_template_variables(
|
||||
template,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
template=template,
|
||||
variables=start_node["data"]["variables"],
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping,
|
||||
)
|
||||
|
||||
prompts = {
|
||||
"text": template
|
||||
}
|
||||
prompts = {"text": template}
|
||||
|
||||
prompt_rules = prompt_template_config['prompt_rules']
|
||||
prompt_rules = prompt_template_config["prompt_rules"]
|
||||
role_prefix = {
|
||||
"user": prompt_rules.get('human_prefix', 'Human'),
|
||||
"assistant": prompt_rules.get('assistant_prefix', 'Assistant')
|
||||
"user": prompt_rules.get("human_prefix", "Human"),
|
||||
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
|
||||
}
|
||||
else:
|
||||
advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template
|
||||
if advanced_completion_prompt_template:
|
||||
text = advanced_completion_prompt_template.prompt
|
||||
text = self._replace_template_variables(
|
||||
text,
|
||||
start_node['data']['variables'],
|
||||
external_data_variable_node_mapping
|
||||
template=text,
|
||||
variables=start_node["data"]["variables"],
|
||||
external_data_variable_node_mapping=external_data_variable_node_mapping,
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
text = text.replace('{{#query#}}', '{{#sys.query#}}')
|
||||
text = text.replace("{{#query#}}", "{{#sys.query#}}")
|
||||
|
||||
prompts = {
|
||||
"text": text,
|
||||
}
|
||||
|
||||
if advanced_completion_prompt_template.role_prefix:
|
||||
if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix:
|
||||
role_prefix = {
|
||||
"user": advanced_completion_prompt_template.role_prefix.user,
|
||||
"assistant": advanced_completion_prompt_template.role_prefix.assistant
|
||||
"assistant": advanced_completion_prompt_template.role_prefix.assistant,
|
||||
}
|
||||
|
||||
memory = None
|
||||
if new_app_mode == AppMode.ADVANCED_CHAT:
|
||||
memory = {
|
||||
"role_prefix": role_prefix,
|
||||
"window": {
|
||||
"enabled": False
|
||||
}
|
||||
}
|
||||
memory = {"role_prefix": role_prefix, "window": {"enabled": False}}
|
||||
|
||||
completion_params = model_config.parameters
|
||||
completion_params.update({"stop": model_config.stop})
|
||||
@ -563,41 +507,42 @@ class WorkflowConverter:
|
||||
"provider": model_config.provider,
|
||||
"name": model_config.model,
|
||||
"mode": model_config.mode,
|
||||
"completion_params": completion_params
|
||||
"completion_params": completion_params,
|
||||
},
|
||||
"prompt_template": prompts,
|
||||
"memory": memory,
|
||||
"context": {
|
||||
"enabled": knowledge_retrieval_node is not None,
|
||||
"variable_selector": ["knowledge_retrieval", "result"]
|
||||
if knowledge_retrieval_node is not None else None
|
||||
if knowledge_retrieval_node is not None
|
||||
else None,
|
||||
},
|
||||
"vision": {
|
||||
"enabled": file_upload is not None,
|
||||
"variable_selector": ["sys", "files"] if file_upload is not None else None,
|
||||
"configs": {
|
||||
"detail": file_upload.image_config['detail']
|
||||
} if file_upload is not None else None
|
||||
}
|
||||
}
|
||||
"configs": {"detail": file_upload.image_config["detail"]}
|
||||
if file_upload is not None and file_upload.image_config is not None
|
||||
else None,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _replace_template_variables(self, template: str,
|
||||
variables: list[dict],
|
||||
external_data_variable_node_mapping: dict[str, str] = None) -> str:
|
||||
def _replace_template_variables(
|
||||
self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Replace Template Variables
|
||||
:param template: template
|
||||
:param variables: list of variables
|
||||
:param external_data_variable_node_mapping: external data variable node mapping
|
||||
:return:
|
||||
"""
|
||||
for v in variables:
|
||||
template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
|
||||
template = template.replace("{{" + v["variable"] + "}}", "{{#start." + v["variable"] + "#}}")
|
||||
|
||||
if external_data_variable_node_mapping:
|
||||
for variable, code_node_id in external_data_variable_node_mapping.items():
|
||||
template = template.replace('{{' + variable + '}}',
|
||||
'{{#' + code_node_id + '.result#}}')
|
||||
template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}")
|
||||
|
||||
return template
|
||||
|
||||
@ -613,11 +558,8 @@ class WorkflowConverter:
|
||||
"data": {
|
||||
"title": "END",
|
||||
"type": NodeType.END.value,
|
||||
"outputs": [{
|
||||
"variable": "result",
|
||||
"value_selector": ["llm", "text"]
|
||||
}]
|
||||
}
|
||||
"outputs": [{"variable": "result", "value_selector": ["llm", "text"]}],
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_answer_node(self) -> dict:
|
||||
@ -629,11 +571,7 @@ class WorkflowConverter:
|
||||
return {
|
||||
"id": "answer",
|
||||
"position": None,
|
||||
"data": {
|
||||
"title": "ANSWER",
|
||||
"type": NodeType.ANSWER.value,
|
||||
"answer": "{{#llm.text#}}"
|
||||
}
|
||||
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str) -> dict:
|
||||
@ -643,11 +581,7 @@ class WorkflowConverter:
|
||||
:param target: target node id
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"id": f"{source}-{target}",
|
||||
"source": source,
|
||||
"target": target
|
||||
}
|
||||
return {"id": f"{source}-{target}", "source": source, "target": target}
|
||||
|
||||
def _append_node(self, graph: dict, node: dict) -> dict:
|
||||
"""
|
||||
@ -657,9 +591,9 @@ class WorkflowConverter:
|
||||
:param node: Node to append
|
||||
:return:
|
||||
"""
|
||||
previous_node = graph['nodes'][-1]
|
||||
graph['nodes'].append(node)
|
||||
graph['edges'].append(self._create_edge(previous_node['id'], node['id']))
|
||||
previous_node = graph["nodes"][-1]
|
||||
graph["nodes"].append(node)
|
||||
graph["edges"].append(self._create_edge(previous_node["id"], node["id"]))
|
||||
return graph
|
||||
|
||||
def _get_new_app_mode(self, app_model: App) -> AppMode:
|
||||
@ -673,14 +607,20 @@ class WorkflowConverter:
|
||||
else:
|
||||
return AppMode.ADVANCED_CHAT
|
||||
|
||||
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str):
|
||||
"""
|
||||
Get API Based Extension
|
||||
:param tenant_id: tenant id
|
||||
:param api_based_extension_id: api based extension id
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(APIBasedExtension).filter(
|
||||
APIBasedExtension.tenant_id == tenant_id,
|
||||
APIBasedExtension.id == api_based_extension_id
|
||||
).first()
|
||||
api_based_extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}")
|
||||
|
||||
return api_based_extension
|
||||
|
||||
@ -297,6 +297,7 @@ class WorkflowService:
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
name=args.get('name'),
|
||||
icon_type=args.get('icon_type'),
|
||||
icon=args.get('icon'),
|
||||
icon_background=args.get('icon_background'),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user