mirror of https://github.com/langgenius/dify.git
synced 2026-01-19 19:55:06 +08:00

Compare commits: 74 commits, feat/conta ... 1.8.1
| SHA1 |
|---|
| c7700ac176 |
| d011ddfc64 |
| 67cc70ad61 |
| a384ae9140 |
| a7627882a7 |
| 8eae7a95be |
| dabf266048 |
| 462e764a3c |
| 0e8a37dca8 |
| bffbe54120 |
| b673560b92 |
| 9e125e2029 |
| b88146c443 |
| c40cb7fd59 |
| 9d5956cef8 |
| 1fff4620e6 |
| c3820f55f4 |
| 60c5bdd62f |
| 5092e5f631 |
| c0bd35594e |
| bc9efa7ea8 |
| f540d0b747 |
| 7bcaa513fa |
| d33dfee8a3 |
| b5216df4fe |
| 25a11bfafc |
| 8fcc864fb7 |
| ed5ed0306e |
| a418c43d32 |
| 5aa8c9c8df |
| 32972b45db |
| af351b1723 |
| af88266212 |
| b14119b531 |
| 68c75f221b |
| 7b379e2a61 |
| c373b734bc |
| 2ac8f8003f |
| d6b3df8f6f |
| deea07e905 |
| 0caa94bd1c |
| a32dde5428 |
| 067b0d07c4 |
| 044f96bd93 |
| ca96350707 |
| be3af1e234 |
| 2e89d29c87 |
| e4eb9f7c55 |
| dd6547de06 |
| 84d09b8b8a |
| 2c462154f7 |
| b810efdb3f |
| ae04ccc445 |
| f7ac1192ae |
| e048588a88 |
| 2042353526 |
| 9486715929 |
| 64319c0d56 |
| acd209a890 |
| bd482eb8ef |
| 5b3cc560d5 |
| d41d4deaac |
| 208ce4e774 |
| 414ee51975 |
| d5a521eef2 |
| 1b401063e8 |
| 60d9d0584a |
| ffba341258 |
| f11131f8b5 |
| 2e6e414a9e |
| c45d676477 |
| b8d8dddd5a |
| c45c22b1b2 |
| 3d57a9ccdc |
.github/workflows/api-tests.yml (vendored, 15 changed lines)

@@ -42,11 +42,7 @@ jobs:
      - name: Run Unit tests
        run: |
          uv run --project api bash dev/pytest/pytest_unit_tests.sh
      - name: Run ty check
        run: |
          cd api
          uv add --dev ty
          uv run ty check || true

      - name: Run pyrefly check
        run: |
          cd api
@@ -66,15 +62,6 @@ jobs:
      - name: Run dify config tests
        run: uv run --project api dev/pytest/pytest_config_tests.py

      - name: MyPy Cache
        uses: actions/cache@v4
        with:
          path: api/.mypy_cache
          key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}

      - name: Run MyPy Checks
        run: dev/mypy-check

      - name: Set up dotenvs
        run: |
          cp docker/.env.example docker/.env
.github/workflows/style.yml (vendored, 8 changed lines)

@@ -44,6 +44,10 @@ jobs:
        if: steps.changed-files.outputs.any_changed == 'true'
        run: uv sync --project api --dev

      - name: Run Basedpyright Checks
        if: steps.changed-files.outputs.any_changed == 'true'
        run: dev/basedpyright-check

      - name: Dotenv check
        if: steps.changed-files.outputs.any_changed == 'true'
        run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
@@ -89,7 +93,9 @@ jobs:
      - name: Web style check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm run lint
        run: |
          pnpm run lint
          pnpm run eslint

  docker-compose-template:
    name: Docker Compose Template
.gitignore (vendored, 5 changed lines)

@@ -123,10 +123,12 @@ venv.bak/
# mkdocs documentation
/site

# mypy
# type checking
.mypy_cache/
.dmypy.json
dmypy.json
pyrightconfig.json
!api/pyrightconfig.json

# Pyre type checker
.pyre/
@@ -195,7 +197,6 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/README.md
pyrightconfig.json
api/.vscode
# vscode Code History Extension
.history
@@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/  # Integration tests
./dev/reformat                            # Run all formatters and linters
uv run --project api ruff check --fix ./  # Fix linting issues
uv run --project api ruff format ./       # Format code
uv run --project api mypy .               # Type checking
uv run --directory api basedpyright       # Type checking
```

### Frontend (Web)
Makefile (60 changed lines)

@@ -4,6 +4,48 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest

# Backend Development Environment Setup
.PHONY: dev-setup prepare-docker prepare-web prepare-api

# Default dev setup target
dev-setup: prepare-docker prepare-web prepare-api
	@echo "✅ Backend development environment setup complete!"

# Step 1: Prepare Docker middleware
prepare-docker:
	@echo "🐳 Setting up Docker middleware..."
	@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
	@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
	@echo "✅ Docker middleware started"

# Step 2: Prepare web environment
prepare-web:
	@echo "🌐 Setting up web environment..."
	@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
	@cd web && pnpm install
	@cd web && pnpm build
	@echo "✅ Web environment prepared (not started)"

# Step 3: Prepare API environment
prepare-api:
	@echo "🔧 Setting up API environment..."
	@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
	@cd api && uv sync --dev
	@cd api && uv run flask db upgrade
	@echo "✅ API environment prepared (not started)"

# Clean dev environment
dev-clean:
	@echo "⚠️ Stopping Docker containers..."
	@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
	@echo "🗑️ Removing volumes..."
	@rm -rf docker/volumes/db
	@rm -rf docker/volumes/redis
	@rm -rf docker/volumes/plugin_daemon
	@rm -rf docker/volumes/weaviate
	@rm -rf api/storage
	@echo "✅ Cleanup complete"

# Build Docker images
build-web:
	@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@@ -39,5 +81,21 @@ build-push-web: build-web push-web
build-push-all: build-all push-all
	@echo "All Docker images have been built and pushed."

# Help target
help:
	@echo "Development Setup Targets:"
	@echo "  make dev-setup      - Run all setup steps for backend dev environment"
	@echo "  make prepare-docker - Set up Docker middleware"
	@echo "  make prepare-web    - Set up web environment"
	@echo "  make prepare-api    - Set up API environment"
	@echo "  make dev-clean      - Stop Docker middleware containers"
	@echo ""
	@echo "Docker Build Targets:"
	@echo "  make build-web      - Build web Docker image"
	@echo "  make build-api      - Build API Docker image"
	@echo "  make build-all      - Build all Docker images"
	@echo "  make push-all       - Push all Docker images"
	@echo "  make build-push-all - Build and push all Docker images"

# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help
@@ -460,16 +460,6 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800

# GraphEngine Worker Pool Configuration
# Minimum number of workers per GraphEngine instance (default: 1)
GRAPH_ENGINE_MIN_WORKERS=1
# Maximum number of workers per GraphEngine instance (default: 10)
GRAPH_ENGINE_MAX_WORKERS=10
# Queue depth threshold that triggers worker scale up (default: 3)
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
# Seconds of idle time before scaling down workers (default: 5.0)
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0

# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
@@ -1,112 +0,0 @@
[importlinter]
root_packages =
    core
    configs
    controllers
    models
    tasks
    services

[importlinter:contract:workflow]
name = Workflow
type=layers
layers =
    graph_engine
    graph_events
    graph
    nodes
    node_events
    entities
containers =
    core.workflow
ignore_imports =
    core.workflow.nodes.base.node -> core.workflow.graph_events
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
    core.workflow.nodes.loop.loop_node -> core.workflow.graph
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
    core.workflow.nodes.node_factory -> core.workflow.graph

[importlinter:contract:rsc]
name = RSC
type = layers
layers =
    graph_engine
    response_coordinator
containers =
    core.workflow.graph_engine

[importlinter:contract:worker]
name = Worker
type = layers
layers =
    graph_engine
    worker
containers =
    core.workflow.graph_engine

[importlinter:contract:graph-engine-architecture]
name = Graph Engine Architecture
type = layers
layers =
    graph_engine
    orchestration
    command_processing
    event_management
    error_handling
    graph_traversal
    state_management
    worker_management
    domain
containers =
    core.workflow.graph_engine

[importlinter:contract:domain-isolation]
name = Domain Model Isolation
type = forbidden
source_modules =
    core.workflow.graph_engine.domain
forbidden_modules =
    core.workflow.graph_engine.worker_management
    core.workflow.graph_engine.command_channels
    core.workflow.graph_engine.layers
    core.workflow.graph_engine.protocols

[importlinter:contract:worker-management]
name = Worker Management
type = forbidden
source_modules =
    core.workflow.graph_engine.worker_management
forbidden_modules =
    core.workflow.graph_engine.orchestration
    core.workflow.graph_engine.command_processing
    core.workflow.graph_engine.event_management

[importlinter:contract:error-handling-strategies]
name = Error Handling Strategies
type = independence
modules =
    core.workflow.graph_engine.error_handling.abort_strategy
    core.workflow.graph_engine.error_handling.retry_strategy
    core.workflow.graph_engine.error_handling.fail_branch_strategy
    core.workflow.graph_engine.error_handling.default_value_strategy

[importlinter:contract:graph-traversal-components]
name = Graph Traversal Components
type = layers
layers =
    edge_processor
    skip_propagator
containers =
    core.workflow.graph_engine.graph_traversal

[importlinter:contract:command-channels]
name = Command Channels Independence
type = independence
modules =
    core.workflow.graph_engine.command_channels.in_memory_channel
    core.workflow.graph_engine.command_channels.redis_channel
@@ -108,5 +108,5 @@ uv run celery -A app.celery beat
../dev/reformat              # Run all formatters and linters
uv run ruff check --fix ./   # Fix linting issues
uv run ruff format ./        # Format code
uv run mypy .                # Type checking
uv run basedpyright .        # Type checking
```
@@ -1,11 +0,0 @@
from tests.integration_tests.utils.parent_class import ParentClass


class ChildClass(ParentClass):
    """Test child class for module import helper tests"""

    def __init__(self, name):
        super().__init__(name)

    def get_name(self):
        return f"Child: {self.name}"
@@ -13,6 +13,7 @@ from sqlalchemy.exc import SQLAlchemyError

from configs import dify_config
from constants.languages import languages
from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
@@ -30,7 +31,6 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from models.provider_ids import ToolProviderID
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
@@ -571,7 +571,7 @@ def old_metadata_migration():
    for document in documents:
        if document.doc_metadata:
            doc_metadata = document.doc_metadata
            for key, value in doc_metadata.items():
            for key in doc_metadata:
                for field in BuiltInField:
                    if field.value == key:
                        break
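A note on the loop fix in `old_metadata_migration` above: the `value` binding was never used, and iterating a dict directly yields its keys. A quick illustration with placeholder data:

```python
doc_metadata = {"document_name": "a.pdf", "uploader": "alice"}

# Before: unpacks a value that is never used.
for key, value in doc_metadata.items():
    print(key)

# After: iterating a dict yields its keys directly.
for key in doc_metadata:
    print(key)
```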
@@ -529,28 +529,6 @@ class WorkflowConfig(BaseSettings):
        default=200 * 1024,
    )

    # GraphEngine Worker Pool Configuration
    GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
        description="Minimum number of workers per GraphEngine instance",
        default=1,
    )

    GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
        description="Maximum number of workers per GraphEngine instance",
        default=10,
    )

    GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
        description="Queue depth threshold that triggers worker scale up",
        default=3,
    )

    GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
        description="Seconds of idle time before scaling down workers",
        default=5.0,
        ge=0.1,
    )


class WorkflowNodeExecutionConfig(BaseSettings):
    """
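These removed fields pair with the `GRAPH_ENGINE_*` entries dropped from the env example earlier: pydantic-settings populates each `BaseSettings` field from the identically named environment variable and validates it at startup. A minimal sketch of that pattern (the class name is illustrative, not from the codebase):

```python
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class GraphEngineWorkerPoolConfig(BaseSettings):
    """Illustrative stand-in for the removed WorkflowConfig fields."""

    GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
        description="Minimum number of workers per GraphEngine instance",
        default=1,
    )
    GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
        description="Seconds of idle time before scaling down workers",
        default=5.0,
        ge=0.1,  # constraints reject nonsense values at startup
    )


# Reads GRAPH_ENGINE_MIN_WORKERS etc. from the process environment
# (or a .env file) and raises a ValidationError on bad values.
config = GraphEngineWorkerPoolConfig()
print(config.GRAPH_ENGINE_MIN_WORKERS)
```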
@@ -29,7 +29,7 @@ class NacosSettingsSource(RemoteSettingsSource):
        try:
            content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
            self.remote_configs = self._parse_config(content)
        except Exception as e:
        except Exception:
            logger.exception("[get-access-token] exception occurred")
            raise
@@ -27,7 +27,7 @@ class NacosHttpClient:
            response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
            response.raise_for_status()
            return response.text
        except requests.exceptions.RequestException as e:
        except requests.RequestException as e:
            return f"Request to Nacos failed: {e}"

    def _inject_auth_info(self, headers, params, module="config"):
@@ -77,6 +77,6 @@ class NacosHttpClient:
        self.token = response_data.get("accessToken")
        self.token_ttl = response_data.get("tokenTtl", 18000)
        self.token_expire_time = current_time + self.token_ttl - 10
        except Exception as e:
        except Exception:
            logger.exception("[get-access-token] exception occur")
            raise
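Two recurring cleanups in this batch: `requests.RequestException` is the top-level alias for `requests.exceptions.RequestException`, and the `as e` binding is dropped wherever the handler only calls `logger.exception`, which records the active traceback on its own. A minimal sketch of the pattern (the URL and message are illustrative):

```python
import logging

import requests

logger = logging.getLogger(__name__)


def fetch_config(url: str) -> str:
    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
        return response.text
    except requests.RequestException:
        # logger.exception logs at ERROR level and appends the current
        # traceback automatically, so the exception object is not needed.
        logger.exception("request to %s failed", url)
        raise
```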
@@ -19,6 +19,7 @@ language_timezone_mapping = {
    "fa-IR": "Asia/Tehran",
    "sl-SI": "Europe/Ljubljana",
    "th-TH": "Asia/Bangkok",
    "id-ID": "Asia/Jakarta",
}

languages = list(language_timezone_mapping.keys())
@@ -130,15 +130,19 @@ class InsertExploreAppApi(Resource):
            app.is_public = False

        with Session(db.engine) as session:
            installed_apps = session.execute(
                select(InstalledApp).where(
                    InstalledApp.app_id == recommended_app.app_id,
                    InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
            installed_apps = (
                session.execute(
                    select(InstalledApp).where(
                        InstalledApp.app_id == recommended_app.app_id,
                        InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
                    )
                )
            ).all()
                .scalars()
                .all()
            )

            for installed_app in installed_apps:
                db.session.delete(installed_app)
            for installed_app in installed_apps:
                session.delete(installed_app)

        db.session.delete(recommended_app)
        db.session.commit()
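The change above swaps `session.execute(...).all()`, which yields `Row` tuples, for `.scalars().all()`, which yields the mapped objects, and routes the deletes through the same session that loaded them. A self-contained sketch of the pattern with an illustrative model (not the real dify schema):

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class InstalledApp(Base):  # illustrative model, not the real dify schema
    __tablename__ = "installed_apps"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str]


engine = create_engine("sqlite://")  # in-memory placeholder
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(InstalledApp(app_id="demo"))
    session.commit()

    # .scalars() unwraps each Row into an InstalledApp instance;
    # without it the loop below would iterate one-element Row tuples.
    installed_apps = session.execute(
        select(InstalledApp).where(InstalledApp.app_id == "demo")
    ).scalars().all()

    for installed_app in installed_apps:
        # delete through the session that loaded the objects
        session.delete(installed_app)
    session.commit()
```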
@@ -84,7 +84,7 @@ class BaseApiKeyListResource(Resource):
            flask_restx.abort(
                400,
                message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
                code="max_keys_exceeded",
                custom="max_keys_exceeded",
            )

        key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -237,9 +237,14 @@ class AppExportApi(Resource):
        # Add include_secret params
        parser = reqparse.RequestParser()
        parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
        parser.add_argument("workflow_id", type=str, location="args")
        args = parser.parse_args()

        return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
        return {
            "data": AppDslService.export_dsl(
                app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
            )
        }


class AppNameApi(Resource):
@@ -16,10 +16,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.login import login_required
from models import App
from services.workflow_service import WorkflowService


class RuleGenerateApi(Resource):
@@ -138,6 +135,9 @@ class InstructionGenerateApi(Resource):
        try:
            # Generate from nothing for a workflow node
            if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
                from models import App, db
                from services.workflow_service import WorkflowService

                app = db.session.query(App).where(App.id == args["flow_id"]).first()
                if not app:
                    return {"error": f"app {args['flow_id']} not found"}, 400
@@ -24,7 +24,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -414,12 +413,7 @@ class WorkflowTaskStopApi(Resource):
        if not current_user.is_editor:
            raise Forbidden()

        # Stop using both mechanisms for backward compatibility
        # Legacy stop flag mechanism (without user check)
        AppQueueManager.set_stop_flag_no_user_check(task_id)

        # New graph engine command channel mechanism
        GraphEngineManager.send_stop_command(task_id)
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)

        return {"result": "success"}
@@ -532,7 +526,7 @@ class PublishedWorkflowApi(Resource):
        )

        app_model.workflow_id = workflow.id
        db.session.commit()
        db.session.commit()  # NOTE: this is necessary for update app_model.workflow_id
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
@@ -27,7 +27,9 @@ class WorkflowAppLogApi(Resource):
        """
        parser = reqparse.RequestParser()
        parser.add_argument("keyword", type=str, location="args")
        parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
        parser.add_argument(
            "status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
        )
        parser.add_argument(
            "created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
        )
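`choices` makes reqparse reject any query value outside the whitelist with a 400 before the handler body runs, so adding `partial-succeeded` above is what makes that status filterable. The parser in isolation, as a minimal sketch:

```python
from flask_restx import reqparse

parser = reqparse.RequestParser()
parser.add_argument(
    "status",
    type=str,
    # Anything outside this whitelist is rejected with a 400.
    choices=["succeeded", "failed", "stopped", "partial-succeeded"],
    location="args",
)
# Inside a request handler, parser.parse_args() would then accept
# GET /logs?status=partial-succeeded and reject status=bogus.
```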
@@ -17,11 +17,10 @@ from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode
from models import App, AppMode, db
from models.account import Account
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
            return {"error": "Invalid code"}, 400
        try:
            oauth_provider.get_access_token(code)
        except requests.exceptions.HTTPError as e:
        except requests.HTTPError as e:
            logger.exception(
                "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
            )
@@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
            return {"error": "Invalid provider"}, 400
        try:
            oauth_provider.sync_data_source(binding_id)
        except requests.exceptions.HTTPError as e:
        except requests.HTTPError as e:
            logger.exception(
                "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
            )
@@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
            language = "en-US"
        try:
            account = AccountService.get_user_through_email(args["email"])
        except AccountRegisterError as are:
        except AccountRegisterError:
            raise AccountInFreezeError()

        if account is None:
@@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
            language = "en-US"
        try:
            account = AccountService.get_user_through_email(args["email"])
        except AccountRegisterError as are:
        except AccountRegisterError:
            raise AccountInFreezeError()

        if account is None:
@@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
        AccountService.revoke_email_code_login_token(args["token"])
        try:
            account = AccountService.get_user_through_email(user_email)
        except AccountRegisterError as are:
        except AccountRegisterError:
            raise AccountInFreezeError()
        if account:
            tenants = TenantService.get_join_tenants(account)
@@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
            )
        except WorkSpaceNotAllowedCreateError:
            raise NotAllowedCreateWorkspace()
        except AccountRegisterError as are:
        except AccountRegisterError:
            raise AccountInFreezeError()
        except WorkspacesLimitExceededError:
            raise WorkspacesLimitExceeded()
@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
        try:
            token = oauth_provider.get_access_token(code)
            user_info = oauth_provider.get_user_info(token)
        except requests.exceptions.RequestException as e:
        except requests.RequestException as e:
            error_text = e.response.text if e.response else str(e)
            logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
            return {"error": "OAuth process failed"}, 400
@@ -44,22 +44,19 @@ def oauth_server_access_token_required(view):
        if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
            raise BadRequest("Invalid oauth_provider_app")

        if not request.headers.get("Authorization"):
            raise BadRequest("Authorization is required")

        authorization_header = request.headers.get("Authorization")
        if not authorization_header:
            raise BadRequest("Authorization header is required")

        parts = authorization_header.split(" ")
        parts = authorization_header.strip().split(" ")
        if len(parts) != 2:
            raise BadRequest("Invalid Authorization header format")

        token_type = parts[0]
        if token_type != "Bearer":
        token_type = parts[0].strip()
        if token_type.lower() != "bearer":
            raise BadRequest("token_type is invalid")

        access_token = parts[1]
        access_token = parts[1].strip()
        if not access_token:
            raise BadRequest("access_token is required")
@@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource):
        parser.add_argument("refresh_token", type=str, required=False, location="json")
        parsed_args = parser.parse_args()

        grant_type = OAuthGrantType(parsed_args["grant_type"])
        try:
            grant_type = OAuthGrantType(parsed_args["grant_type"])
        except ValueError:
            raise BadRequest("invalid grant_type")

        if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
            if not parsed_args["code"]:
@@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource):
                    "refresh_token": refresh_token,
                }
            )
        else:
            raise BadRequest("invalid grant_type")


class OAuthServerUserAccountApi(Resource):
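The hardened checks above accept a case-insensitive scheme name (HTTP auth schemes are case-insensitive), tolerate stray whitespace, and turn an unknown `grant_type` into a 400 rather than an unhandled `ValueError`. The token parsing in isolation, as a framework-free sketch (the function name is illustrative):

```python
def parse_bearer_token(authorization_header: str | None) -> str:
    """Mirror of the hardened Authorization header checks above."""
    if not authorization_header:
        raise ValueError("Authorization header is required")

    parts = authorization_header.strip().split(" ")
    if len(parts) != 2:
        raise ValueError("Invalid Authorization header format")

    token_type = parts[0].strip()
    # Scheme names are case-insensitive, so "bearer" and "Bearer" both pass.
    if token_type.lower() != "bearer":
        raise ValueError("token_type is invalid")

    access_token = parts[1].strip()
    if not access_token:
        raise ValueError("access_token is required")
    return access_token


assert parse_bearer_token("bearer abc123") == "abc123"
assert parse_bearer_token(" Bearer abc123 ") == "abc123"
```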
@@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
@@ -214,7 +215,7 @@ class DataSourceNotionApi(Resource):
        workspace_id = notion_info["workspace_id"]
        for page in notion_info["pages"]:
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                datasource_type=DatasourceType.NOTION.value,
                notion_info={
                    "notion_workspace_id": workspace_id,
                    "notion_obj_id": page["page_id"],
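This file and the dataset controllers below replace bare literals (`"upload_file"`, `"notion_import"`, `"website_crawl"`) with `DatasourceType` members, so a typo fails loudly at attribute lookup instead of silently at runtime. A minimal sketch of the pattern; the enum values are inferred from the literals being replaced in these diffs:

```python
from enum import Enum


class DatasourceType(Enum):
    # Values inferred from the replaced literals in these diffs.
    FILE = "upload_file"
    NOTION = "notion_import"
    WEBSITE = "website_crawl"


# .value is exactly the old literal, so downstream string comparisons keep
# working; a misspelling like DatasourceType.NOTOIN raises AttributeError.
assert DatasourceType.NOTION.value == "notion_import"
```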
@@ -19,8 +19,10 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@@ -30,7 +32,6 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService


@@ -422,7 +423,9 @@ class DatasetIndexingEstimateApi(Resource):
        if file_details:
            for file_detail in file_details:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
                    datasource_type=DatasourceType.FILE.value,
                    upload_file=file_detail,
                    document_model=args["doc_form"],
                )
                extract_settings.append(extract_setting)
        elif args["info_list"]["data_source_type"] == "notion_import":
@@ -431,7 +434,7 @@ class DatasetIndexingEstimateApi(Resource):
            workspace_id = notion_info["workspace_id"]
            for page in notion_info["pages"]:
                extract_setting = ExtractSetting(
                    datasource_type="notion_import",
                    datasource_type=DatasourceType.NOTION.value,
                    notion_info={
                        "notion_workspace_id": workspace_id,
                        "notion_obj_id": page["page_id"],
@@ -445,7 +448,7 @@ class DatasetIndexingEstimateApi(Resource):
            website_info_list = args["info_list"]["website_info_list"]
            for url in website_info_list["urls"]:
                extract_setting = ExtractSetting(
                    datasource_type="website_crawl",
                    datasource_type=DatasourceType.WEBSITE.value,
                    website_info={
                        "provider": website_info_list["provider"],
                        "job_id": website_info_list["job_id"],
@@ -40,6 +40,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.document_fields import (
@@ -354,9 +355,6 @@ class DatasetInitApi(Resource):
        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()
        knowledge_config = KnowledgeConfig(**args)
        if knowledge_config.indexing_technique == "high_quality":
            if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
@@ -428,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
            raise NotFound("File not found.")

        extract_setting = ExtractSetting(
            datasource_type="upload_file", upload_file=file, document_model=document.doc_form
            datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
        )

        indexing_runner = IndexingRunner()
@@ -488,13 +486,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                    raise NotFound("File not found.")

                extract_setting = ExtractSetting(
                    datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
                    datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
                )
                extract_settings.append(extract_setting)

            elif document.data_source_type == "notion_import":
                extract_setting = ExtractSetting(
                    datasource_type="notion_import",
                    datasource_type=DatasourceType.NOTION.value,
                    notion_info={
                        "notion_workspace_id": data_source_info["notion_workspace_id"],
                        "notion_obj_id": data_source_info["notion_page_id"],
@@ -506,7 +504,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                extract_settings.append(extract_setting)
            elif document.data_source_type == "website_crawl":
                extract_setting = ExtractSetting(
                    datasource_type="website_crawl",
                    datasource_type=DatasourceType.WEBSITE.value,
                    website_info={
                        "provider": data_source_info["provider"],
                        "job_id": data_source_info["job_id"],
@@ -61,7 +61,6 @@ class ConversationApi(InstalledAppResource):
            ConversationService.delete(app_model, conversation_id, current_user)
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        WebConversationService.unpin(app_model, conversation_id, current_user)

        return {"result": "success"}, 204
@@ -20,7 +20,6 @@ from core.errors.error import (
    QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user
from models.model import AppMode, InstalledApp
@@ -79,11 +78,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
            raise NotWorkflowAppError()
        assert current_user is not None

        # Stop using both mechanisms for backward compatibility
        # Legacy stop flag mechanism (without user check)
        AppQueueManager.set_stop_flag_no_user_check(task_id)

        # New graph engine command channel mechanism
        GraphEngineManager.send_stop_command(task_id)
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

        return {"result": "success"}
@@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):

        parser = reqparse.RequestParser()
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
        parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
        args = parser.parse_args()

        model_provider_service = ModelProviderService()
@@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
        parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
        args = parser.parse_args()

        model_provider_service = ModelProviderService()
@@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):

        model_load_balancing_service = ModelLoadBalancingService()
        is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
            tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
            tenant_id=tenant_id,
            provider=provider,
            model=args["model"],
            model_type=args["model_type"],
            config_from=args.get("config_from", ""),
        )

        if args.get("config_from", "") == "predefined-model":
@@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
            choices=[mt.value for mt in ModelType],
            location="json",
        )
        parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
        parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()

@@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
        )
        parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
        parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
        args = parser.parse_args()

        model_provider_service = ModelProviderService()
@@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
@@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db as global_db
from models import db as global_db


@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
@@ -1,8 +1,12 @@
from base64 import b64encode
from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request

from configs import dify_config
@@ -10,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser


def billing_inner_api_only(view):
def billing_inner_api_only(view: Callable[P, R]):
    @wraps(view)
    def decorated(*args, **kwargs):
    def decorated(*args: P.args, **kwargs: P.kwargs):
        if not dify_config.INNER_API:
            abort(404)

@@ -26,9 +30,9 @@ def billing_inner_api_only(view):
    return decorated


def enterprise_inner_api_only(view):
def enterprise_inner_api_only(view: Callable[P, R]):
    @wraps(view)
    def decorated(*args, **kwargs):
    def decorated(*args: P.args, **kwargs: P.kwargs):
        if not dify_config.INNER_API:
            abort(404)

@@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
    return decorated


def plugin_inner_api_only(view):
def plugin_inner_api_only(view: Callable[P, R]):
    @wraps(view)
    def decorated(*args, **kwargs):
    def decorated(*args: P.args, **kwargs: P.kwargs):
        if not dify_config.PLUGIN_DAEMON_KEY:
            abort(404)
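Typing decorators with `ParamSpec` lets the wrapper preserve the wrapped view's exact signature for type checkers instead of degrading it to `(*Any, **Any)`. A self-contained sketch of the pattern (the guard flag is illustrative, standing in for a config check like `dify_config.INNER_API`):

```python
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

FEATURE_ENABLED = True  # illustrative stand-in for a config flag


def feature_gate(view: Callable[P, R]) -> Callable[P, R]:
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        # P.args / P.kwargs tie the wrapper's parameters to the wrapped
        # function, so callers get full signature checking.
        if not FEATURE_ENABLED:
            raise RuntimeError("feature disabled")
        return view(*args, **kwargs)

    return decorated


@feature_gate
def add(a: int, b: int) -> int:
    return a + b


add(1, 2)      # type checks
# add("1", 2)  # flagged by a type checker: wrong argument type
```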
@@ -55,7 +55,7 @@ class AudioApi(Resource):
        file = request.files["file"]

        try:
            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)

            return response
        except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
        args = file_preview_parser.parse_args()

        # Validate file ownership and get file objects
        message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
        _, upload_file = self._validate_file_ownership(file_id, app_model.id)

        # Get file content generator
        try:
@@ -26,8 +26,7 @@ from core.errors.error import (
)
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.manager import GraphEngineManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
@@ -263,12 +262,7 @@ class WorkflowTaskStopApi(Resource):
        if app_mode != AppMode.WORKFLOW:
            raise NotWorkflowAppError()

        # Stop using both mechanisms for backward compatibility
        # Legacy stop flag mechanism (without user check)
        AppQueueManager.set_stop_flag_no_user_check(task_id)

        # New graph engine command channel mechanism
        GraphEngineManager.send_stop_command(task_id)
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

        return {"result": "success"}
@@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
    validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
@@ -410,7 +410,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
        DocumentService.document_create_args_validate(knowledge_config)

        try:
            documents, batch = DocumentService.save_document_with_dataset_id(
            documents, _ = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                knowledge_config=knowledge_config,
                account=dataset.created_by_account,
@@ -1,7 +1,7 @@
import time
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from enum import StrEnum, auto
from functools import wraps
from typing import Optional

@@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService


class WhereisUserArg(Enum):
class WhereisUserArg(StrEnum):
    """
    Enum for whereis_user_arg.
    """

    QUERY = "query"
    JSON = "json"
    FORM = "form"
    QUERY = auto()
    JSON = auto()
    FORM = auto()


class FetchUserArg(BaseModel):
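In a `StrEnum` (Python 3.11+), `auto()` generates the lower-cased member name as the value, so `QUERY = auto()` is equivalent to the old `QUERY = "query"`, and members are real strings that compare equal to plain literals. A quick sketch:

```python
from enum import StrEnum, auto


class WhereisUserArg(StrEnum):
    QUERY = auto()  # value is "query"
    JSON = auto()   # value is "json"
    FORM = auto()   # value is "form"


# StrEnum members subclass str, so existing string handling keeps working.
assert WhereisUserArg.QUERY == "query"
assert f"fetch user from {WhereisUserArg.FORM}" == "fetch user from form"
```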
@@ -291,27 +291,28 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
    if not user_id:
        user_id = "DEFAULT-USER"

    end_user = (
        db.session.query(EndUser)
        .where(
            EndUser.tenant_id == app_model.tenant_id,
            EndUser.app_id == app_model.id,
            EndUser.session_id == user_id,
            EndUser.type == "service_api",
    with Session(db.engine, expire_on_commit=False) as session:
        end_user = (
            session.query(EndUser)
            .where(
                EndUser.tenant_id == app_model.tenant_id,
                EndUser.app_id == app_model.id,
                EndUser.session_id == user_id,
                EndUser.type == "service_api",
            )
            .first()
        )
        .first()
    )

    if end_user is None:
        end_user = EndUser(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type="service_api",
            is_anonymous=user_id == "DEFAULT-USER",
            session_id=user_id,
        )
        db.session.add(end_user)
        db.session.commit()
        if end_user is None:
            end_user = EndUser(
                tenant_id=app_model.tenant_id,
                app_id=app_model.id,
                type="service_api",
                is_anonymous=user_id == "DEFAULT-USER",
                session_id=user_id,
            )
            session.add(end_user)
            session.commit()

    return end_user
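`expire_on_commit=False` matters here because `end_user` escapes the `with` block: with the default of `True`, committing expires the instance's loaded attributes, and reading them after the session closes raises a detached-instance error. A minimal sketch reusing the illustrative mapping style from the earlier example:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class EndUser(Base):  # illustrative model, not the real dify schema
    __tablename__ = "end_users"
    id: Mapped[int] = mapped_column(primary_key=True)
    session_id: Mapped[str]


engine = create_engine("sqlite://")  # in-memory placeholder
Base.metadata.create_all(engine)


def get_or_create_end_user(session_id: str) -> EndUser:
    # expire_on_commit=False keeps attributes loaded after commit(),
    # so the caller can safely read end_user.id once the session closes.
    with Session(engine, expire_on_commit=False) as session:
        end_user = session.query(EndUser).filter_by(session_id=session_id).first()
        if end_user is None:
            end_user = EndUser(session_id=session_id)
            session.add(end_user)
            session.commit()
        return end_user


print(get_or_create_end_user("DEFAULT-USER").id)
```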
@@ -73,8 +73,6 @@ class ConversationApi(WebApiResource):
            ConversationService.delete(app_model, conversation_id, end_user)
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        WebConversationService.unpin(app_model, conversation_id, end_user)

        return {"result": "success"}, 204
@@ -21,7 +21,6 @@ from core.errors.error import (
    QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
@@ -111,12 +110,7 @@ class WorkflowTaskStopApi(WebApiResource):
        if app_mode != AppMode.WORKFLOW:
            raise NotWorkflowAppError()

        # Stop using both mechanisms for backward compatibility
        # Legacy stop flag mechanism (without user check)
        AppQueueManager.set_stop_flag_no_user_check(task_id)

        # New graph engine command channel mechanism
        GraphEngineManager.send_stop_command(task_id)
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

        return {"result": "success"}
@@ -4,6 +4,7 @@ from functools import wraps
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized

from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@@ -49,18 +50,19 @@ def decode_jwt_token():
        decoded = PassportService().verify(tk)
        app_code = decoded.get("app_code")
        app_id = decoded.get("app_id")
        app_model = db.session.scalar(select(App).where(App.id == app_id))
        site = db.session.scalar(select(Site).where(Site.code == app_code))
        if not app_model:
            raise NotFound()
        if not app_code or not site:
            raise BadRequest("Site URL is no longer valid.")
        if app_model.enable_site is False:
            raise BadRequest("Site is disabled.")
        end_user_id = decoded.get("end_user_id")
        end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
        if not end_user:
            raise NotFound()
        with Session(db.engine, expire_on_commit=False) as session:
            app_model = session.scalar(select(App).where(App.id == app_id))
            site = session.scalar(select(Site).where(Site.code == app_code))
            if not app_model:
                raise NotFound()
            if not app_code or not site:
                raise BadRequest("Site URL is no longer valid.")
            if app_model.enable_site is False:
                raise BadRequest("Site is disabled.")
            end_user_id = decoded.get("end_user_id")
            end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id))
            if not end_user:
                raise NotFound()

        # for enterprise webapp auth
        app_web_auth_enabled = False
@@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner):
        """
        Save agent thought
        """
        agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
        stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
        agent_thought = db.session.scalar(stmt)
        if not agent_thought:
            raise ValueError("agent thought not found")
@@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner):
        return result

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
        stmt = select(MessageFile).where(MessageFile.message_id == message.id)
        files = db.session.scalars(stmt).all()
        if not files:
            return UserPromptMessage(content=message.query)
        if message.app_model_config:
@@ -4,8 +4,8 @@ from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from models.provider_ids import ModelProviderID


class ModelConfigManager:
@@ -26,7 +26,7 @@ class MoreLikeThisConfigManager:
    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
        try:
            return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
        except ValidationError as e:
        except ValidationError:
            raise ValueError(
                "more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
            )
@@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

        worker_thread.start()

        # release database connection, because the following new thread operations may take a long time
        db.session.refresh(workflow)
        db.session.refresh(message)
        # db.session.refresh(user)
        db.session.close()

        # return response or stream generator
        response = self._handle_advanced_chat_response(
            application_generate_entity=application_generate_entity,
@@ -1,11 +1,11 @@
import logging
import time
from collections.abc import Mapping
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -23,17 +23,16 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from models.workflow import ConversationVariable, WorkflowType

logger = logging.getLogger(__name__)
@ -73,36 +72,30 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
|
||||
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph_runtime_state.variable_pool = variable_pool
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph_runtime_state.variable_pool = variable_pool
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
query = self.application_generate_entity.query
|
||||
@ -149,31 +142,20 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=cast(list[VariableUnion], conversation_variables),
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
|
||||
db.session.close()
|
||||
|
||||
# RUN WORKFLOW
|
||||
# Create Redis command channel for this workflow execution
|
||||
task_id = self.application_generate_entity.task_id
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType.value_of(self._workflow.type),
|
||||
graph=graph,
|
||||
graph_config=self._workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
@ -184,11 +166,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks,
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
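
Note on the hunk above: `WorkflowEntry` now receives its type via `WorkflowType.value_of(self._workflow.type)`. A minimal sketch of such a lookup helper, assuming a string-valued enum (an illustrative stand-in, not the actual `models.workflow.WorkflowType`):

    from enum import Enum

    class WorkflowType(Enum):  # hypothetical minimal stand-in
        WORKFLOW = "workflow"
        CHAT = "chat"

        @classmethod
        def value_of(cls, value: str) -> "WorkflowType":
            for member in cls:
                if member.value == value:
                    return member
            raise ValueError(f"invalid workflow type value: {value}")

    assert WorkflowType.value_of("chat") is WorkflowType.CHAT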

@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())


@ -31,9 +31,14 @@ from core.app.entities.queue_entities import (
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@ -60,15 +65,14 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
@ -306,13 +310,8 @@ class AdvancedChatAppGenerateTaskPipeline:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err)

def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
) -> Generator[StreamResponse, None, None]:
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# Override graph runtime state - this is a side effect but necessary
graph_runtime_state = event.graph_runtime_state

with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
@ -333,15 +332,14 @@
"""Handle node retry events."""
self._ensure_workflow_initialized()

with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if node_retry_resp:
yield node_retry_resp
@ -375,13 +373,12 @@
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)

with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

self._save_output_for_event(event, workflow_node_execution.id)

@ -390,7 +387,9 @@

def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -435,6 +434,32 @@
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)

def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()

parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp

def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()

parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp

def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -726,6 +751,8 @@
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -773,6 +800,8 @@
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -785,6 +814,17 @@
)
return

# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return

# For unhandled events, we continue (original behavior)
return
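
The dispatch logic above pairs an exact-type handler table with isinstance fallbacks for unions of related events that share one handler. A condensed, runnable sketch of that pattern (class and handler names are illustrative only, not the actual Dify API):

    from collections.abc import Generator

    class NodeStarted: ...
    class BranchSucceeded: ...
    class BranchFailed: ...

    def handle_node_started(event) -> Generator[str, None, None]:
        yield "node_started"

    def handle_branch_finished(event) -> Generator[str, None, None]:
        yield "branch_finished"

    HANDLERS = {NodeStarted: handle_node_started}

    def dispatch(event) -> Generator[str, None, None]:
        handler = HANDLERS.get(type(event))
        if handler is not None:
            yield from handler(event)  # exact-type match: one class, one handler
            return
        # Unions of related events fall back to isinstance checks
        if isinstance(event, (BranchSucceeded, BranchFailed)):
            yield from handle_branch_finished(event)
            return
        # Unhandled events are ignored (original behavior)

    assert list(dispatch(BranchFailed())) == ["branch_finished"]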

@ -808,6 +848,11 @@
graph_runtime_state = event.graph_runtime_state
yield from self._handle_workflow_started_event(event)

case QueueTextChunkEvent():
yield from self._handle_text_chunk_event(
event, tts_publisher=tts_publisher, queue_message=queue_message
)

case QueueErrorEvent():
yield from self._handle_error_event(event)
break
@ -886,10 +931,6 @@
self._task_state.metadata.usage = usage
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)

def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""

@ -1,6 +1,8 @@
import logging
from typing import cast

from sqlalchemy import select

from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)

app_record = db.session.query(App).where(App.id == app_config.app_id).first()
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")

@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):

if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first()
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
from core.workflow.enums import NodeType
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaver,
DraftVariableSaverFactory,

@ -1,7 +1,7 @@
import queue
import time
from abc import abstractmethod
from enum import Enum
from enum import IntEnum, auto
from typing import Any, Optional

from sqlalchemy.orm import DeclarativeMeta
@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
from extensions.ext_redis import redis_client


class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2
class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto()
TASK_PIPELINE = auto()


class AppQueueManager:
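
The `PublishFrom` rewrite above is value-preserving: `auto()` numbers members from 1, and `IntEnum` members compare equal to plain ints, so anything already serialized as 1 or 2 keeps working. A quick check:

    from enum import IntEnum, auto

    class PublishFrom(IntEnum):  # mirrors the new definition above
        APPLICATION_MANAGER = auto()  # auto() starts numbering at 1
        TASK_PIPELINE = auto()

    assert PublishFrom.APPLICATION_MANAGER == 1
    assert PublishFrom.TASK_PIPELINE == 2
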
@ -126,21 +126,6 @@
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)

@classmethod
def set_stop_flag_no_user_check(cls, task_id: str) -> None:
"""
Set task stop flag without user permission check.
This method allows stopping workflows without user context.

:param task_id: The task ID to stop
:return:
"""
if not task_id:
return

stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)

def _is_stopped(self) -> bool:
"""
Check if task is stopped
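
The removed `set_stop_flag_no_user_check` duplicated the flag write kept at the top of the hunk: a short-lived Redis key serves as a cross-process cancellation signal, and the 600-second TTL lets stale flags expire on their own. A minimal sketch of the idea (the key format below is hypothetical; the real one comes from `_generate_stopped_cache_key`):

    def set_stop_flag(redis_client, task_id: str) -> None:
        # hypothetical key layout, for illustration only
        redis_client.setex(f"task_stopped:{task_id}", 600, 1)

    def is_stopped(redis_client, task_id: str) -> bool:
        return redis_client.get(f"task_stopped:{task_id}") is not None
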
@ -174,7 +159,7 @@
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for key, value in data.items():
for value in data.values():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:
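
The loop change above is mechanical: the key is never read, so iterating `data.values()` states the intent directly. A standalone version of the traversal for reference:

    def check_for_models(data):
        """Recursively visit every value nested inside dicts and lists."""
        if isinstance(data, dict):
            for value in data.values():  # key unused, so .values() beats .items()
                check_for_models(value)
        elif isinstance(data, list):
            for item in data:
                check_for_models(item)

    check_for_models({"a": [1, {"b": 2}]})  # visits 1 and 2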

@ -1,6 +1,8 @@
import logging
from typing import cast

from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)

app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")


@ -1,7 +1,7 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast

from sqlalchemy.orm import Session

@ -16,9 +16,14 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
@ -31,16 +36,18 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
@ -164,10 +171,11 @@ class WorkflowResponseConverter:

# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
)

return response
@ -175,7 +183,11 @@
def workflow_node_finish_to_stream_response(
self,
*,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
@ -209,6 +221,9 @@
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
@ -260,6 +275,50 @@
),
)

def workflow_parallel_branch_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunStartedEvent,
) -> ParallelBranchStartStreamResponse:
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
created_at=int(time.time()),
),
)

def workflow_parallel_branch_finished_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
),
)

def workflow_iteration_start_to_stream_response(
self,
*,
@ -274,11 +333,13 @@
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)

@ -296,10 +357,15 @@
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)

@ -318,7 +384,7 @@
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
outputs=json_converter.to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
@ -332,6 +398,8 @@
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)

@ -345,7 +413,7 @@
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@ -369,7 +437,7 @@
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
index=event.index,
pre_loop_output=event.output,
created_at=int(time.time()),
@ -377,6 +445,7 @@
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)

@ -394,7 +463,7 @@
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},

@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload

from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from sqlalchemy import select

from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
message = db.session.scalar(stmt)

if not message:
raise MessageNotExistsError()

@ -1,6 +1,8 @@
import logging
from typing import cast

from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)

app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")


@ -3,6 +3,9 @@ import logging
from collections.abc import Generator
from typing import Optional, Union, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):

def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
app_model_config = db.session.scalar(stmt)

if not app_model_config:
raise AppModelConfigBrokenError()
@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
with Session(db.engine, expire_on_commit=False) as session:
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))

if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = db.session.query(Message).where(Message.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message = session.scalar(select(Message).where(Message.id == message_id))

if message is None:
raise MessageNotExistsError("Message not exists")

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
@ -53,6 +53,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None]: ...

@overload
@ -66,6 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]: ...

@overload
@ -79,6 +81,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

def generate(
@ -91,6 +94,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []

@ -182,6 +186,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)

def _generate(
@ -195,6 +200,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
@ -208,6 +214,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@ -230,6 +237,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)
@ -426,7 +434,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""

with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
@ -456,6 +474,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,

@ -1,7 +1,7 @@
import logging
import time
from typing import cast
from typing import Optional, cast

from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -9,14 +9,13 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from models.enums import UserFrom
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowType

logger = logging.getLogger(__name__)

@ -32,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
) -> None:
@ -41,6 +41,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id

@ -51,30 +52,24 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)

workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())

# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
)
else:
inputs = self.application_generate_entity.inputs
@ -97,26 +92,15 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
conversation_variables=[],
)

graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

# init graph
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
)
graph = self._init_graph(graph_config=self._workflow.graph_dict)

# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)

workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@ -127,11 +111,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
)

generator = workflow_entry.run()
generator = workflow_entry.run(callbacks=workflow_callbacks)

for event in generator:
self._handle_event(workflow_entry, event)

@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Optional, Union
from typing import Any, Optional, Union

from sqlalchemy.orm import Session

@ -14,7 +14,6 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
@ -26,9 +25,14 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@ -53,8 +57,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -296,16 +300,15 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()

with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if response:
yield response
@ -346,7 +349,9 @@ class WorkflowAppGenerateTaskPipeline:

def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -365,6 +370,32 @@ class WorkflowAppGenerateTaskPipeline:
if node_failed_response:
yield node_failed_response

def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()

parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp

def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()

parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp

def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -586,6 +617,8 @@ class WorkflowAppGenerateTaskPipeline:
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -600,7 +633,7 @@ class WorkflowAppGenerateTaskPipeline:

def _dispatch_event(
self,
event: AppQueueEvent,
event: Any,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
@ -627,6 +660,8 @@ class WorkflowAppGenerateTaskPipeline:
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -639,6 +674,17 @@ class WorkflowAppGenerateTaskPipeline:
)
return

# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return

# Handle workflow failed and stop events with isinstance check
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(

@ -2,7 +2,6 @@ from collections.abc import Mapping
from typing import Any, cast

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@ -14,9 +13,14 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
@ -24,39 +28,42 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_events import (
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeInIterationFailedEvent,
NodeInLoopFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
|
||||
|
||||
|
||||
@ -72,14 +79,7 @@ class WorkflowBasedAppRunner:
|
||||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
) -> Graph:
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
@ -91,28 +91,8 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
@ -124,7 +104,6 @@ class WorkflowBasedAppRunner:
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
@ -149,9 +128,7 @@ class WorkflowBasedAppRunner:
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get("iteration_id", "") == node_id
|
||||
or node.get("id") == f"{node_id}start"
|
||||
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
@ -168,25 +145,8 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
@ -241,7 +201,6 @@ class WorkflowBasedAppRunner:
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
@ -266,9 +225,7 @@ class WorkflowBasedAppRunner:
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get("loop_id", "") == node_id
|
||||
or node.get("id") == f"{node_id}start"
|
||||
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
@ -285,25 +242,8 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)

if not graph:
raise ValueError("graph not found in workflow")
@ -370,21 +310,29 @@ class WorkflowBasedAppRunner:
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
node_run_result = event.route_node_state.node_run_result
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
@ -395,8 +343,6 @@ class WorkflowBasedAppRunner:
error=event.error,
execution_metadata=execution_metadata,
retry_index=event.retry_index,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunStartedEvent):
@ -404,30 +350,44 @@ class WorkflowBasedAppRunner:
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
parallel_mode_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
node_run_result = event.route_node_state.node_run_result
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@ -436,18 +396,34 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)

elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -458,21 +434,93 @@ class WorkflowBasedAppRunner:
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)

elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeInLoopFailedEvent):
self._publish_event(
QueueNodeInLoopFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk,
from_variable_selector=list(event.selector),
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -485,10 +533,10 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunAgentLogEvent):
elif isinstance(event, AgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.message_id,
id=event.id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
@ -499,13 +547,51 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
)
)
elif isinstance(event, NodeRunIterationStartedEvent):
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, IterationRunStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -513,41 +599,55 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, NodeRunIterationNextEvent):
elif isinstance(event, IterationRunNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
elif isinstance(event, NodeRunLoopStartedEvent):
elif isinstance(event, LoopRunStartedEvent):
self._publish_event(
QueueLoopStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -555,32 +655,42 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, NodeRunLoopNextEvent):
elif isinstance(event, LoopRunNextEvent):
self._publish_event(
QueueLoopNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_loop_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
self._publish_event(
QueueLoopCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
)
)

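The hunks above replace lookups through `route_node_state` with fields carried directly on each graph-engine event, which is what lets every `if event.route_node_state.node_run_result else {}` fallback disappear. A minimal sketch of the resulting dispatch shape, using stand-in classes rather than the real Dify types:

from dataclasses import dataclass, field
from typing import Any


@dataclass
class NodeRunResult:
    inputs: dict[str, Any] = field(default_factory=dict)
    outputs: dict[str, Any] = field(default_factory=dict)


@dataclass
class NodeRunSucceededEvent:
    id: str
    node_id: str
    node_run_result: NodeRunResult  # the result always travels with the event


@dataclass
class QueueNodeSucceededEvent:
    node_execution_id: str
    node_id: str
    inputs: dict[str, Any]
    outputs: dict[str, Any]


def map_event(event: object) -> QueueNodeSucceededEvent | None:
    # New style: no optional route_node_state to guard against.
    if isinstance(event, NodeRunSucceededEvent):
        result = event.node_run_result
        return QueueNodeSucceededEvent(
            node_execution_id=event.id,
            node_id=event.node_id,
            inputs=result.inputs,
            outputs=result.outputs,
        )
    return None
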
@ -7,9 +7,11 @@ from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData


class QueueEvent(StrEnum):
@ -41,6 +43,9 @@ class QueueEvent(StrEnum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error"
PING = "ping"
@ -75,7 +80,15 @@ class QueueIterationStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime

node_run_index: int
@ -95,9 +108,20 @@ class QueueIterationNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None


class QueueIterationCompletedEvent(AppQueueEvent):
@ -110,7 +134,15 @@ class QueueIterationCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime

node_run_index: int
@ -131,7 +163,7 @@ class QueueLoopStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -159,7 +191,7 @@ class QueueLoopNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -172,6 +204,7 @@ class QueueLoopNextEvent(AppQueueEvent):
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current loop
duration: Optional[float] = None


class QueueLoopCompletedEvent(AppQueueEvent):
@ -184,7 +217,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -331,24 +364,27 @@ class QueueNodeStartedEvent(AppQueueEvent):

node_execution_id: str
node_id: str
node_title: str
node_type: NodeType
node_run_index: int = 1 # FIXME(-LAN-): may not used
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None

# FIXME(-LAN-): only for ToolNode, need to refactor
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
provider_id: str


class QueueNodeSucceededEvent(AppQueueEvent):
"""
@ -360,6 +396,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -380,6 +417,10 @@ class QueueNodeSucceededEvent(AppQueueEvent):
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None

error: Optional[str] = None
"""single iteration duration map"""
iteration_duration_map: Optional[dict[str, float]] = None
"""single loop duration map"""
loop_duration_map: Optional[dict[str, float]] = None


class QueueAgentLogEvent(AppQueueEvent):
@ -413,6 +454,72 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
retry_index: int # retry index


class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""

event: QueueEvent = QueueEvent.NODE_FAILED

node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime

inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None

error: str


class QueueNodeInLoopFailedEvent(AppQueueEvent):
"""
QueueNodeInLoopFailedEvent entity
"""

event: QueueEvent = QueueEvent.NODE_FAILED

node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime

inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None

error: str


class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
@ -423,6 +530,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -455,7 +563,15 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
@ -562,3 +678,61 @@ class WorkflowQueueMessage(QueueMessage):
"""

pass


class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""

event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED

parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""


class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""

event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED

parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""


class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""

event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED

parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
error: str

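The parallel-branch queue events defined above are plain Pydantic models with a fixed `event` discriminator. A minimal usage sketch with a simplified stand-in model (field values are illustrative; `model_dump` assumes Pydantic v2):

from typing import Optional

from pydantic import BaseModel


class ParallelBranchFailed(BaseModel):
    # Simplified stand-in for QueueParallelBranchRunFailedEvent above.
    event: str = "parallel_branch_run_failed"
    parallel_id: str
    parallel_start_node_id: str
    in_iteration_id: Optional[str] = None
    error: str


evt = ParallelBranchFailed(
    parallel_id="parallel-1",
    parallel_start_node_id="node-3",
    error="branch timed out",
)
print(evt.model_dump())  # the serialized form that gets published to the queue
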
@ -7,8 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


class AnnotationReplyAccount(BaseModel):
@ -71,6 +71,8 @@ class StreamEvent(Enum):
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@ -438,6 +440,54 @@ class NodeRetryStreamResponse(StreamResponse):
}


class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
"""

class Data(BaseModel):
"""
Data entity
"""

parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
created_at: int

event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
workflow_run_id: str
data: Data


class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
ParallelBranchFinishedStreamResponse entity
"""

class Data(BaseModel):
"""
Data entity
"""

parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
status: str
error: Optional[str] = None
created_at: int

event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
workflow_run_id: str
data: Data


class IterationNodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@ -456,6 +506,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None

event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
@ -478,7 +530,12 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_iteration_output: Optional[Any] = None
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
duration: Optional[float] = None

event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -510,6 +567,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Optional[Mapping] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None

event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
@ -563,6 +622,7 @@ class LoopNodeNextStreamResponse(StreamResponse):
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
duration: Optional[float] = None

event: StreamEvent = StreamEvent.LOOP_NEXT
workflow_run_id: str

@ -1,6 +1,8 @@
import logging
from typing import Optional

from sqlalchemy import select

from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@ -25,9 +27,8 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)

if not annotation_setting:
return None

@ -96,7 +96,11 @@ class RateLimit:
if isinstance(generator, Mapping):
return generator
else:
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
request_id=request_id,
)


class RateLimitGenerator:

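The `RateLimit` hunk above returns mappings unchanged and wraps streaming generators in a `RateLimitGenerator` proxy. A minimal sketch of that wrapping idea, assuming the proxy's job is to free the rate-limit slot once the stream ends (simplified; not the actual class):

from collections.abc import Generator, Mapping


class SimpleRateLimitGenerator:
    """Proxy that frees the rate-limit slot when the stream is exhausted."""

    def __init__(self, generator: Generator, release) -> None:
        self._generator = generator
        self._release = release

    def __iter__(self):
        try:
            yield from self._generator
        finally:
            self._release()  # free the slot even if the consumer stops early


def wrap(result, release):
    # A complete (mapping) response needs no streaming proxy.
    if isinstance(result, Mapping):
        return result
    return SimpleRateLimitGenerator(result, release)
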
@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e
err = e # ty: ignore [invalid-assignment]
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))

@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param event: agent thought event
:return:
"""
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)

if agent_thought:
return AgentThoughtStreamResponse(

@ -3,6 +3,8 @@ from threading import Thread
from typing import Optional, Union

from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import (
@ -84,7 +86,8 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)

if not conversation:
return
@ -98,7 +101,7 @@ class MessageCycleManager:
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
conversation.name = name
except Exception as e:
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
pass
@ -143,7 +146,8 @@ class MessageCycleManager:
:param event: event
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))

if message_file and message_file.url is not None:
# get tool file id
@ -183,7 +187,8 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE

return MessageStreamResponse(

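Several hunks above apply the same mechanical migration from the legacy `Query` API to a SQLAlchemy 2.0-style `select()` read back with `Session.scalar()`. A minimal sketch of the pattern, with an illustrative model:

from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Conversation(Base):
    __tablename__ = "conversations"
    id: Mapped[str] = mapped_column(primary_key=True)


def load_conversation(session: Session, conversation_id: str) -> Conversation | None:
    # Before: session.query(Conversation).where(Conversation.id == conversation_id).first()
    # After: build a select() statement and let the session return a single scalar.
    stmt = select(Conversation).where(Conversation.id == conversation_id)
    return session.scalar(stmt)
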
@ -1,6 +1,8 @@
import logging
from collections.abc import Sequence

from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@ -57,17 +60,14 @@ class DatasetIndexToolCallbackHandler:
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(

@ -1,5 +1,6 @@
import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
@ -28,6 +29,7 @@ from core.model_runtime.entities.provider_entities import (
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import (
@ -40,7 +42,6 @@ from models.provider import (
ProviderType,
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID

logger = logging.getLogger(__name__)

@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
with Session(db.engine) as new_session:
return _validate(new_session)

def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
def _generate_provider_credential_name(self, session) -> str:
"""
Generate a unique credential name for provider.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
),
)

def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
"""
Generate a unique credential name for custom model.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)

def _generate_next_api_key_name(self, session, query_factory) -> str:
"""
Generate next available API KEY name by finding the highest numbered suffix.
:param session: database session
:param query_factory: function that returns the SQLAlchemy query
:return: next available API KEY name
"""
try:
stmt = query_factory()
credential_records = session.execute(stmt).scalars().all()

if not credential_records:
return "API KEY 1"

# Extract numbers from API KEY pattern using list comprehension
pattern = re.compile(r"^API KEY\s+(\d+)$")
numbers = [
int(match.group(1))
for cr in credential_records
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
]

# Return next sequential number
next_number = max(numbers, default=0) + 1
return f"API KEY {next_number}"

except Exception as e:
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"

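`_generate_next_api_key_name` above scans existing credential names for an exact `API KEY <n>` pattern and returns the highest number plus one. The core of that logic, extracted as a standalone sketch:

import re


def next_api_key_name(existing_names: list[str]) -> str:
    # Only names of the exact form "API KEY <n>" count; anything else is ignored.
    pattern = re.compile(r"^API KEY\s+(\d+)$")
    numbers = [
        int(match.group(1))
        for name in existing_names
        if name and (match := pattern.match(name.strip()))
    ]
    return f"API KEY {max(numbers, default=0) + 1}"


assert next_api_key_name([]) == "API KEY 1"
assert next_api_key_name(["API KEY 1", "API KEY 3", "prod key"]) == "API KEY 4"
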
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
"""
Add custom provider credentials.
:param credentials: provider credentials
@ -351,8 +410,11 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)

credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session)
@ -395,7 +457,7 @@ class ProviderConfiguration(BaseModel):
self,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update a saved provider credential (by credential_id).
@ -406,7 +468,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
@ -428,9 +490,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()

if credential_name:
credential_record.credential_name = credential_name
session.commit()

if provider_record and provider_record.credential_id == credential_id:
@ -532,13 +594,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()

lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)

# Check if this is the currently active credential
provider_record = self._get_provider_record(session)
@ -627,7 +683,6 @@ class ProviderConfiguration(BaseModel):
Get custom model credentials.
"""
# get provider model

model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -823,7 +878,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)

def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create a custom model credential.
@ -834,10 +889,15 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
@ -881,7 +941,7 @@ class ProviderConfiguration(BaseModel):
raise

def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None:
"""
Update a custom model credential.
@ -894,7 +954,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
@ -926,8 +986,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()

if provider_model_record and provider_model_record.credential_id == credential_id:
@ -983,12 +1044,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)

# Check if this is the currently active credential
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
@ -1055,6 +1111,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id,
)
else:
@ -1125,7 +1182,6 @@ class ProviderConfiguration(BaseModel):
"""
Get provider model setting.
"""

model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -1209,7 +1265,6 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""

model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -1608,11 +1663,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "custom_model"
]

if len(provider_model_lb_configs) > 1:
load_balancing_enabled = True

if any(config.name == "__delete__" for config in provider_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2

provider_models.append(
ModelWithProviderEntity(
@ -1634,6 +1687,8 @@ class ProviderConfiguration(BaseModel):
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
if model_configuration.unadded_to_model_list:
continue
if model and model != model_configuration.model:
continue
try:
@ -1666,11 +1721,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "provider"
]

if len(custom_model_lb_configs) > 1:
load_balancing_enabled = True

if any(config.name == "__delete__" for config in custom_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2

if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
status = ModelStatus.CREDENTIAL_REMOVED

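With `credential_name` widened to `str | None` throughout this file, an explicit name is checked for uniqueness while a missing one falls back to the generated scheme. A condensed, self-contained sketch of that branch (unlike the real helper, this version fills the first free number rather than taking the highest suffix plus one):

def resolve_credential_name(requested: str | None, existing: set[str]) -> str:
    if requested:
        if requested in existing:
            raise ValueError(f"Credential with name '{requested}' already exists.")
        return requested
    # No name supplied: fall back to the auto-generated "API KEY <n>" scheme.
    n = 1
    while f"API KEY {n}" in existing:
        n += 1
    return f"API KEY {n}"


assert resolve_credential_name(None, {"API KEY 1"}) == "API KEY 2"
assert resolve_credential_name("prod", {"API KEY 1"}) == "prod"
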
@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
unadded_to_model_list: Optional[bool] = False
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class UnaddedModelConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider unadded model configuration.
|
||||
"""
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class CustomConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider custom configuration.
|
||||
@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
|
||||
|
||||
provider: Optional[CustomProviderConfiguration] = None
|
||||
models: list[CustomModelConfiguration] = []
|
||||
can_added_models: list[UnaddedModelConfiguration] = []
|
||||
|
||||
|
||||
class ModelLoadBalancingConfiguration(BaseModel):
|
||||
@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
enabled: bool = True
|
||||
load_balancing_enabled: bool = False
|
||||
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
|
||||
|
||||
# pydantic configs
|
||||
|
||||
@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
|
||||
timeout=self.timeout,
|
||||
proxies=proxies,
|
||||
)
|
||||
except requests.exceptions.Timeout:
|
||||
except requests.Timeout:
|
||||
raise ValueError("request timeout")
|
||||
except requests.exceptions.ConnectionError:
|
||||
except requests.ConnectionError:
|
||||
raise ValueError("request connection error")
|
||||
|
||||
if response.status_code != 200:
|
||||
|
||||
@ -91,7 +91,7 @@ class Extensible:
|
||||
|
||||
# Find extension class
|
||||
extension_class = None
|
||||
for name, obj in vars(mod).items():
|
||||
for obj in vars(mod).values():
|
||||
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
|
||||
extension_class = obj
|
||||
break
|
||||
@ -123,7 +123,7 @@ class Extensible:
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Error scanning extensions")
|
||||
raise
|
||||
|
||||
|
||||
@ -41,9 +41,3 @@ class Extension:
|
||||
assert module_extension.extension_class is not None
|
||||
t: type = module_extension.extension_class
|
||||
return t
|
||||
|
||||
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
|
||||
module_extension = self.module_extension(module, extension_name)
|
||||
form_schema = module_extension.form_schema
|
||||
|
||||
# TODO validate form_schema
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||
from core.external_data_tool.base import ExternalDataTool
|
||||
from core.helper import encrypter
|
||||
@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
api_based_extension_id = config.get("api_based_extension_id")
|
||||
if not api_based_extension_id:
|
||||
raise ValueError("api_based_extension_id is required")
|
||||
|
||||
# get api_based_extension
|
||||
api_based_extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
stmt = select(APIBasedExtension).where(
|
||||
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
|
||||
)
|
||||
api_based_extension = db.session.scalar(stmt)
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError("api_based_extension_id is invalid")
|
||||
@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
raise ValueError(f"config is required, config: {self.config}")
|
||||
api_based_extension_id = self.config.get("api_based_extension_id")
|
||||
assert api_based_extension_id is not None, "api_based_extension_id is required"
|
||||
|
||||
# get api_based_extension
|
||||
api_based_extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
stmt = select(APIBasedExtension).where(
|
||||
APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
|
||||
)
|
||||
api_based_extension = db.session.scalar(stmt)
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError(
|
||||
|
||||
@ -22,7 +22,6 @@ class ExternalDataToolFactory:
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
# FIXME mypy issue here, figure out how to fix it
extension_class.validate_config(tenant_id, config) # type: ignore

@ -3,7 +3,7 @@ import base64
from libs import rsa


def obfuscated_token(token: str):
def obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:
@ -11,9 +11,13 @@ def obfuscated_token(token: str):
return token[:6] + "*" * 12 + token[-2:]


def full_mask_token(token_length=20):
return "*" * token_length


def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant
from models.engine import db

if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")

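For context on the masking helpers above: `obfuscated_token` keeps a short prefix and suffix and pads the middle with a fixed run of asterisks, so the output leaks only eight characters and a constant length, while the new `full_mask_token` hides everything. A rough sketch of the behavior the hunk implies (the `<= 8` branch is truncated in the diff, so fully masking short tokens is an assumption here):

```python
def obfuscated_token(token: str) -> str:
    if not token:
        return token
    if len(token) <= 8:
        # Assumption: very short tokens are fully masked rather than sliced.
        return "*" * len(token)
    return token[:6] + "*" * 12 + token[-2:]


def full_mask_token(token_length: int = 20) -> str:
    return "*" * token_length


print(obfuscated_token("sk-abcdefghijklmnop"))  # sk-abc************op
print(full_mask_token())                        # twenty asterisks
```
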
@ -1,6 +1,6 @@
from collections.abc import Sequence

import requests
import httpx
from yarl import URL

from configs import dify_config
@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return []

url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()

return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@ -36,13 +36,13 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
return []

url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
except Exception as e:
except Exception:
pass

return result
@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(

def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()

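The three call sites above swap `requests.post` for `httpx.post` one-for-one; httpx's top-level API is deliberately requests-compatible for simple cases (`json=` body, `raise_for_status()`). A minimal sketch (the httpbin URL is purely illustrative):

```python
import httpx

# Module-level shortcut, mirroring the one-for-one swap in the diff.
response = httpx.post("https://httpbin.org/post", json={"plugin_ids": ["a", "b"]})
response.raise_for_status()
print(response.json()["json"])  # {'plugin_ids': ['a', 'b']}
```

One behavioral difference worth knowing: httpx applies a default five-second timeout, whereas requests waits indefinitely unless given one, so this swap also bounds how long a marketplace call can hang.
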
@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]


def load_single_subclass_from_source(
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
*, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
) -> type:
"""
Load a single subclass from the source

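The `AnyStr` to `str` change above is more than cosmetic: `AnyStr` is a constrained type variable meant to link several `str`-or-`bytes` parameters and return values together, so using it once, on a single parameter, adds nothing and confuses type checkers. A small illustration of where `AnyStr` actually earns its keep:

```python
from typing import AnyStr


def concat(a: AnyStr, b: AnyStr) -> AnyStr:
    # AnyStr ties the two parameters and the return type together:
    # both str, or both bytes, never a mix.
    return a + b


concat("foo", "bar")    # ok: str, str -> str
concat(b"foo", b"bar")  # ok: bytes, bytes -> bytes
# concat("foo", b"bar")  # rejected by a type checker
```
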
@ -1,7 +1,7 @@
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from typing import TypeVar

from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file
@ -72,11 +72,14 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
return position_map


T = TypeVar("T")


def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
data: T,
name_func: Callable[[T], str],
) -> bool:
"""
Check if the object should be filtered out.
@ -103,9 +106,9 @@ def is_filtered(

def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> list[Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
@ -122,9 +125,9 @@ def sort_by_position_map(

def sort_to_dict_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.
@ -134,4 +137,4 @@ def sort_to_dict_by_position_map(
:return: an OrderedDict with the sorted pairs of name and object
"""
sorted_items = sort_by_position_map(position_map, data, name_func)
return OrderedDict([(name_func(item), item) for item in sorted_items])
return OrderedDict((name_func(item), item) for item in sorted_items)

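Replacing `Any` with a module-level `T = TypeVar("T")` makes these helpers generic: the element type flows from `data` through `name_func` and back out of the sort, so callers keep precise types instead of `list[Any]`. A condensed sketch of the pattern, with a simplified body and a hypothetical `Plugin` element type:

```python
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


def sort_by_position_map(
    position_map: dict[str, int],
    data: list[T],
    name_func: Callable[[T], str],
) -> list[T]:
    # Simplified body: names missing from the map sort after every mapped one.
    return sorted(
        data, key=lambda item: position_map.get(name_func(item), len(position_map))
    )


class Plugin:  # hypothetical element type
    def __init__(self, name: str) -> None:
        self.name = name


plugins = [Plugin("b"), Plugin("a")]
ordered = sort_by_position_map({"a": 0, "b": 1}, plugins, lambda p: p.name)
# `ordered` is inferred as list[Plugin], not list[Any].
print([p.name for p in ordered])  # ['a', 'b']
```

The final hunk's `OrderedDict((k, v) for ...)` change is a related micro-cleanup: the generator feeds `OrderedDict` directly without materializing an intermediate list.
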
@ -5,9 +5,10 @@ import re
import threading
import time
import uuid
from typing import Any, Optional, cast
from typing import Any, Optional

from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config
@ -18,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
@ -56,13 +58,11 @@ class IndexingRunner:

if not dataset:
raise ValueError("no dataset found")

# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
@ -123,11 +123,8 @@ class IndexingRunner:
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")

@ -208,7 +205,6 @@ class IndexingRunner:
child_documents.append(child_document)
document.children = child_documents
documents.append(document)

# build index
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@ -310,7 +306,8 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
image_file = db.session.scalar(stmt)
if image_file is None:
continue
try:
@ -339,14 +336,14 @@ class IndexingRunner:
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")

file_detail = (
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()

if file_detail:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import":
@ -357,7 +354,7 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
@ -377,7 +374,7 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
@ -400,7 +397,6 @@ class IndexingRunner:
)

# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
if text_doc.metadata is not None:
text_doc.metadata["document_id"] = dataset_document.id

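Swapping the string literals `"upload_file"`, `"notion_import"`, and `"website_crawl"` for `DatasourceType.*.value` centralizes the vocabulary in one enum, so a typo becomes an `AttributeError` at the call site instead of a silently wrong branch. A sketch of the idea; the member names come from the hunks, and the values are inferred from the literals they replace, so treat both as assumptions:

```python
from enum import Enum


class DatasourceType(Enum):
    # Values inferred from the string literals the diff replaces.
    FILE = "upload_file"
    NOTION = "notion_import"
    WEBSITE = "website_crawl"


def describe(datasource_type: str) -> str:
    if datasource_type == DatasourceType.FILE.value:
        return "local upload"
    if datasource_type == DatasourceType.NOTION.value:
        return "notion import"
    return "website crawl"


print(describe(DatasourceType.FILE.value))  # local upload
```
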
@ -28,9 +28,8 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.node_events import AgentLogEvent
from extensions.ext_database import db
from models import App, Message, WorkflowNodeExecutionModel
from core.workflow.graph_engine.entities.event import AgentLogEvent
from models import App, Message, WorkflowNodeExecutionModel, db

logger = logging.getLogger(__name__)

@ -57,11 +56,8 @@ class LLMGenerator:
prompts = [UserPromptMessage(content=prompt)]

with measure_time() as timer:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
@ -70,7 +66,7 @@ class LLMGenerator:
try:
result_dict = json.loads(cleaned_answer)
answer = result_dict["Your Output"]
except json.JSONDecodeError as e:
except json.JSONDecodeError:
logger.exception("Failed to generate name after answer, use query instead")
answer = query
name = answer.strip()
@ -114,13 +110,10 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]

try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
)

text_content = response.message.get_text_content()
@ -163,11 +156,8 @@ class LLMGenerator:
)

try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)

rule_config["prompt"] = cast(str, response.message.content)
@ -213,11 +203,8 @@ class LLMGenerator:
try:
try:
# the first step to generate the task prompt
prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
prompt_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
except InvokeError as e:
error = str(e)
@ -249,11 +236,8 @@ class LLMGenerator:
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]

try:
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
),
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
except InvokeError as e:
@ -261,11 +245,8 @@ class LLMGenerator:
error_step = "generate variables"

try:
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
),
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e:
@ -308,11 +289,8 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = model_config.get("completion_params", {})
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)

generated_code = cast(str, response.message.content)
@ -339,13 +317,10 @@ class LLMGenerator:

prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]

response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
)

answer = cast(str, response.message.content)
@ -368,11 +343,8 @@ class LLMGenerator:
model_parameters = model_config.get("model_parameters", {})

try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)

raw_content = response.message.content
@ -556,11 +528,8 @@ class LLMGenerator:
model_parameters = {"temperature": 0.4}

try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)

generated_raw = cast(str, response.message.content)

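Every hunk in this file applies the same rewrite: `response = cast(LLMResult, model_instance.invoke_llm(...))` becomes an annotated assignment, `response: LLMResult = model_instance.invoke_llm(...)`. The difference matters: `typing.cast` silences the checker unconditionally, while an annotation is verified against the value's declared type, so a future signature change surfaces as an error instead of hiding behind the cast. A minimal sketch with a hypothetical stand-in for the invoke call, simplified so the annotation is strictly checkable (the real method's return type is a union over streaming and non-streaming results):

```python
from typing import cast


class LLMResult:  # hypothetical stand-in
    def __init__(self, text: str) -> None:
        self.text = text


def invoke_llm(stream: bool) -> LLMResult:  # simplified: always non-streaming
    return LLMResult("hello")


# Before: the cast is a no-op at runtime and unchecked at type-check time.
response = cast(LLMResult, invoke_llm(stream=False))

# After: the annotation is actually checked against invoke_llm's return type.
response2: LLMResult = invoke_llm(stream=False)
print(response.text, response2.text)
```
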
@ -101,7 +101,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta

def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True)
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
if b_query:
url_for_resource_discovery += f"?{b_query}"
@ -117,7 +117,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
else:
return False, ""
return False, ""
except httpx.RequestError as e:
except httpx.RequestError:
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""


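The discovery URL above follows the protected-resource metadata convention: the well-known segment is inserted between the host and the original path, not appended to the end. The `_` in the unpacking simply discards the unused `params` slot of `urlparse`'s six-tuple. A sketch of the construction mirroring the hunk (the server URL is illustrative only):

```python
from urllib.parse import urlparse

server_url = "https://mcp.example.com/tenant-a/server?v=1"

scheme, netloc, path, _, query, _fragment = urlparse(server_url, "", True)
url = f"{scheme}://{netloc}/.well-known/oauth-protected-resource{path}"
if query:
    url += f"?{query}"

print(url)
# https://mcp.example.com/.well-known/oauth-protected-resource/tenant-a/server?v=1
```
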
@ -246,6 +246,10 @@ class StreamableHTTPTransport:
logger.debug("Received 202 Accepted")
return

if response.status_code == 204:
logger.debug("Received 204 No Content")
return

if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
self._send_session_terminated_error(

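The new branch treats 204 like the existing 202 case: both statuses by definition carry no message body, so the transport logs and returns early instead of trying to parse a response stream. A compact, standalone sketch of that dispatch (the handler name is hypothetical):

```python
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Statuses that acknowledge the request but carry no message body.
_NO_BODY_STATUSES = {202: "202 Accepted", 204: "204 No Content"}


def handle_response_status(status_code: int) -> bool:
    """Return True when the response is fully handled by its status alone."""
    reason = _NO_BODY_STATUSES.get(status_code)
    if reason is not None:
        logger.debug("Received %s", reason)
        return True
    return False


assert handle_response_status(204) is True
assert handle_response_status(404) is False
```
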
@ -2,7 +2,7 @@ import logging
from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any, Optional, cast
from typing import Any, Optional
from urllib.parse import urlparse

from core.mcp.client.sse_client import sse_client
@ -116,8 +116,7 @@ class MCPClient:

self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
self._session.initialize()
return

except MCPAuthError:

@ -110,9 +110,9 @@ class TokenBufferMemory:
else:
message_limit = 500

stmt = stmt.limit(message_limit)
msg_limit_stmt = stmt.limit(message_limit)

messages = db.session.scalars(stmt).all()
messages = db.session.scalars(msg_limit_stmt).all()

# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message
@ -167,11 +167,11 @@ class TokenBufferMemory:
else:
prompt_messages.append(AssistantPromptMessage(content=message.answer))

if not prompt_messages:
return []
if not prompt_messages:
return []

# prune the chat message if it exceeds the max token limit
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
# prune the chat message if it exceeds the max token limit
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

if curr_message_tokens > max_token_limit:
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:

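The rename in the memory hunk leans on a SQLAlchemy property that is easy to trip over: `select()` statements are immutable, so `stmt.limit(n)` returns a new statement and leaves `stmt` untouched. Reassigning (`stmt = stmt.limit(n)`) works, but binding the result to a new name, as the diff does, keeps the unlimited statement around explicitly and makes it harder to execute the wrong one by accident. A tiny demonstration:

```python
from sqlalchemy import Integer, column, select, table

messages = table("messages", column("id", Integer))

stmt = select(messages)
msg_limit_stmt = stmt.limit(500)

# The original statement is unchanged; .limit() produced a new object.
assert "LIMIT" not in str(stmt)
assert "LIMIT" in str(msg_limit_stmt)
```
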
@ -158,8 +158,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@ -188,8 +186,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@ -214,8 +210,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@ -237,8 +231,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")

self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@ -269,8 +261,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")

self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@ -295,8 +285,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")

self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@ -318,8 +306,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")

self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@ -343,8 +329,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")

self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@ -404,8 +388,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")

self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)

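The deleted `self.model_type_instance = cast(...)` lines were redundant in every method above: the `isinstance` check on the preceding line already narrows the type for any modern checker, so reassigning through `cast` only added noise and a pointless attribute write. A minimal sketch of the narrowing, with hypothetical class names:

```python
class BaseModelType:
    pass


class LargeLanguageModel(BaseModelType):
    def invoke(self) -> str:
        return "ok"


def invoke_llm(model_type_instance: BaseModelType) -> str:
    if not isinstance(model_type_instance, LargeLanguageModel):
        raise TypeError("Model type instance is not LargeLanguageModel")
    # After the isinstance guard the checker already knows the narrowed
    # type, so no cast() or re-assignment is needed.
    return model_type_instance.invoke()


print(invoke_llm(LargeLanguageModel()))  # ok
```
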
@ -24,7 +24,8 @@ from core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from core.plugin.impl.model import PluginModelClient


class AIModel(BaseModel):
@ -52,8 +53,6 @@ class AIModel(BaseModel):

:return: Invoke error mapping
"""
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError

return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
@ -141,8 +140,6 @@ class AIModel(BaseModel):
:param credentials: model credentials
:return: model schema
"""
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials

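This file, and the model subclasses in the hunks that follow, hoist `PluginModelClient` and `PluginDaemonInnerError` from function bodies to module level. Function-local imports are a tool for breaking import cycles or deferring genuinely heavy modules; once the cycle is gone they just re-run the import machinery on every call and hide the dependency from readers. A schematic before/after using a stdlib module as a stand-in:

```python
# Before: dependency hidden inside the function, resolved on every call.
def checksum_deferred(data: bytes) -> str:
    import hashlib  # deferred import, only worthwhile to break a cycle
    return hashlib.sha256(data).hexdigest()


# After: dependency declared once, at module level.
import hashlib


def checksum(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()


assert checksum(b"x") == checksum_deferred(b"x")
```
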
@ -22,6 +22,7 @@ from core.model_runtime.entities.model_entities import (
PriceType,
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient

logger = logging.getLogger(__name__)

@ -141,8 +142,6 @@ class LargeLanguageModel(AIModel):
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]

try:
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
result = plugin_model_manager.invoke_llm(
tenant_id=self.tenant_id,
@ -341,8 +340,6 @@ class LargeLanguageModel(AIModel):
:return:
"""
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id,

@ -5,6 +5,7 @@ from pydantic import ConfigDict

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient


class ModerationModel(AIModel):
@ -30,8 +31,6 @@ class ModerationModel(AIModel):
self.started_at = time.perf_counter()

try:
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_moderation(
tenant_id=self.tenant_id,

@ -3,6 +3,7 @@ from typing import Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient


class RerankModel(AIModel):
@ -35,8 +36,6 @@ class RerankModel(AIModel):
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_rerank(
tenant_id=self.tenant_id,

@ -4,6 +4,7 @@ from pydantic import ConfigDict

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient


class Speech2TextModel(AIModel):
@ -27,8 +28,6 @@ class Speech2TextModel(AIModel):
:return: text for given audio file
"""
try:
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_speech_to_text(
tenant_id=self.tenant_id,

@ -6,6 +6,7 @@ from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient


class TextEmbeddingModel(AIModel):
@ -36,8 +37,6 @@ class TextEmbeddingModel(AIModel):
:param input_type: input type
:return: embeddings result
"""
from core.plugin.impl.model import PluginModelClient

try:
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_text_embedding(
@ -62,8 +61,6 @@ class TextEmbeddingModel(AIModel):
:param texts: texts to embed
:return:
"""
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,

@ -6,6 +6,7 @@ from pydantic import ConfigDict

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient

logger = logging.getLogger(__name__)

@ -41,8 +42,6 @@ class TTSModel(AIModel):
:return: translated audio file
"""
try:
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_tts(
tenant_id=self.tenant_id,
@ -66,8 +65,6 @@ class TTSModel(AIModel):
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
from core.plugin.impl.model import PluginModelClient

plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_tts_model_voices(
tenant_id=self.tenant_id,

@ -20,8 +20,10 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.plugin.entities.plugin import ModelProviderID
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from models.provider_ids import ModelProviderID
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient

logger = logging.getLogger(__name__)

@ -35,8 +37,6 @@ class ModelProviderFactory:
provider_position_map: dict[str, int]

def __init__(self, tenant_id: str) -> None:
from core.plugin.impl.model import PluginModelClient

self.provider_position_map = {}

self.tenant_id = tenant_id
@ -71,7 +71,7 @@ class ModelProviderFactory:

return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()]

def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
@ -109,7 +109,7 @@ class ModelProviderFactory:
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration

def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
"""
Get plugin model provider
:param provider: provider name
@ -366,8 +366,6 @@ class ModelProviderFactory:
mime_type = image_mime_types.get(extension, "image/png")

# get icon bytes from plugin asset manager
from core.plugin.impl.asset import PluginAssetManager

plugin_asset_manager = PluginAssetManager()
return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type

@ -377,6 +375,5 @@ class ModelProviderFactory:
:param provider: provider name
:return: plugin id and provider name
"""

provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name

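With `PluginModelProviderEntity` now imported at module scope, the quoted annotations (`-> "PluginModelProviderEntity"`) in the last file lose their purpose: string annotations exist for names that are not bound at definition time, typically imported under `if TYPE_CHECKING:` to avoid a cycle. A short illustration of the two situations, using a stdlib type as a stand-in:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for annotations only (e.g. to break an import cycle).
    from decimal import Decimal


def parse_price_lazy(raw: str) -> "Decimal":
    # Quoted annotation: Decimal is not bound at runtime in this branch.
    from decimal import Decimal
    return Decimal(raw)


from decimal import Decimal  # real module-level import: quoting is no longer needed


def parse_price(raw: str) -> Decimal:
    return Decimal(raw)


print(parse_price("1.50"))  # 1.50
```
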
@ -1,6 +1,7 @@
from typing import Optional

from pydantic import BaseModel, Field
from sqlalchemy import select

from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token
@ -87,10 +88,9 @@ class ApiModeration(Moderation):

@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
extension = db.session.scalar(stmt)

return extension

@ -20,7 +20,6 @@ class ModerationFactory:
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore

Some files were not shown because too many files have changed in this diff