mirror of
https://github.com/langgenius/dify.git
synced 2026-01-28 15:56:00 +08:00
Compare commits
33 Commits
refactor/a
...
fix/trigge
| Author | SHA1 | Date | |
|---|---|---|---|
| 354d0e2038 | |||
| 0c495c5d75 | |||
| 186f89a9c7 | |||
| e48419937b | |||
| 5eaf0c733a | |||
| f561656a89 | |||
| f01f555146 | |||
| 47d0e400ae | |||
| 8724ba04aa | |||
| 6fd001c660 | |||
| e8e386a6b9 | |||
| eba5eac3fa | |||
| 19008dce13 | |||
| 92011d0a31 | |||
| a51ced0a4f | |||
| dad8e408b0 | |||
| d941201a3e | |||
| dd988d42c2 | |||
| a43d2ec4f0 | |||
| 7c12e923b6 | |||
| b9f1d65d4f | |||
| b4e2af96e2 | |||
| 9d38af6d99 | |||
| 0772d49257 | |||
| 67eb8c052d | |||
| 5c4028d557 | |||
| 55e6bca11c | |||
| 67657c2f48 | |||
| e8f9d64651 | |||
| 1f8c730259 | |||
| 8d45755303 | |||
| 6342d196e8 | |||
| 5dc5709d58 |
@ -0,0 +1,27 @@
|
||||
# Notes: `large_language_model.py`
|
||||
|
||||
## Purpose
|
||||
|
||||
Provides the base `LargeLanguageModel` implementation used by the model runtime to invoke plugin-backed LLMs and to
|
||||
bridge plugin daemon streaming semantics back into API-layer entities (`LLMResult`, `LLMResultChunk`).
|
||||
|
||||
## Key behaviors / invariants
|
||||
|
||||
- `invoke(..., stream=False)` still calls the plugin in streaming mode and then synthesizes a single `LLMResult` from
|
||||
the first yielded `LLMResultChunk`.
|
||||
- Plugin invocation is wrapped by `_invoke_llm_via_plugin(...)`, and `stream=False` normalization is handled by
|
||||
`_normalize_non_stream_plugin_result(...)` / `_build_llm_result_from_first_chunk(...)`.
|
||||
- Tool call deltas are merged incrementally via `_increase_tool_call(...)` to support multiple provider chunking
|
||||
patterns (IDs anchored to first chunk, every chunk, or missing entirely).
|
||||
- A tool-call delta with an empty `id` requires at least one existing tool call; otherwise we raise `ValueError` to
|
||||
surface invalid delta sequences explicitly.
|
||||
- Callback invocation is centralized in `_run_callbacks(...)` to ensure consistent error handling/logging.
|
||||
- For compatibility with dify issue `#17799`, `prompt_messages` may be removed by the plugin daemon in chunks and must
|
||||
be re-attached in this layer before callbacks/consumers use them.
|
||||
- Callback hooks (`on_before_invoke`, `on_new_chunk`, `on_after_invoke`, `on_invoke_error`) must not break invocation
|
||||
unless `callback.raise_error` is true.
|
||||
|
||||
## Test focus
|
||||
|
||||
- `api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py` validates tool-call delta merging and
|
||||
patches `_gen_tool_call_id` for deterministic IDs.
|
||||
171
api/README.md
171
api/README.md
@ -1,6 +1,6 @@
|
||||
# Dify Backend API
|
||||
|
||||
## Usage
|
||||
## Setup and Run
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
@ -8,48 +8,77 @@
|
||||
> [`uv`](https://docs.astral.sh/uv/) as the package manager
|
||||
> for Dify API backend service.
|
||||
|
||||
1. Start the docker-compose stack
|
||||
`uv` and `pnpm` are required to run the setup and development commands below.
|
||||
|
||||
The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
|
||||
### Using scripts (recommended)
|
||||
|
||||
The scripts resolve paths relative to their location, so you can run them from anywhere.
|
||||
|
||||
1. Run setup (copies env files and installs dependencies).
|
||||
|
||||
```bash
|
||||
cd ../docker
|
||||
cp middleware.env.example middleware.env
|
||||
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
cd ../api
|
||||
./dev/setup
|
||||
```
|
||||
|
||||
1. Copy `.env.example` to `.env`
|
||||
1. Review `api/.env`, `web/.env.local`, and `docker/middleware.env` values (see the `SECRET_KEY` note below).
|
||||
|
||||
```cli
|
||||
cp .env.example .env
|
||||
1. Start middleware (PostgreSQL/Redis/Weaviate).
|
||||
|
||||
```bash
|
||||
./dev/start-docker-compose
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
|
||||
1. Start backend (runs migrations first).
|
||||
|
||||
1. Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
bash for Linux
|
||||
|
||||
```bash for Linux
|
||||
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
|
||||
```bash
|
||||
./dev/start-api
|
||||
```
|
||||
|
||||
bash for Mac
|
||||
1. Start Dify [web](../web) service.
|
||||
|
||||
```bash for Mac
|
||||
secret_key=$(openssl rand -base64 42)
|
||||
sed -i '' "/^SECRET_KEY=/c\\
|
||||
SECRET_KEY=${secret_key}" .env
|
||||
```bash
|
||||
./dev/start-web
|
||||
```
|
||||
|
||||
1. Create environment.
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
|
||||
First, you need to add the uv package manager, if you don't have it already.
|
||||
1. Optional: start the worker service (async tasks, runs from `api`).
|
||||
|
||||
```bash
|
||||
./dev/start-worker
|
||||
```
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks).
|
||||
|
||||
```bash
|
||||
./dev/start-beat
|
||||
```
|
||||
|
||||
### Manual commands
|
||||
|
||||
<details>
|
||||
<summary>Show manual setup and run steps</summary>
|
||||
|
||||
These commands assume you start from the repository root.
|
||||
|
||||
1. Start the docker-compose stack.
|
||||
|
||||
The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
|
||||
|
||||
```bash
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
# Use mysql or another vector database profile if you are not using postgres/weaviate.
|
||||
docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
```
|
||||
|
||||
1. Copy env files.
|
||||
|
||||
```bash
|
||||
cp api/.env.example api/.env
|
||||
cp web/.env.example web/.env.local
|
||||
```
|
||||
|
||||
1. Install UV if needed.
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
@ -57,60 +86,96 @@
|
||||
brew install uv
|
||||
```
|
||||
|
||||
1. Install dependencies
|
||||
1. Install API dependencies.
|
||||
|
||||
```bash
|
||||
uv sync --dev
|
||||
cd api
|
||||
uv sync --group dev
|
||||
```
|
||||
|
||||
1. Run migrate
|
||||
|
||||
Before the first launch, migrate the database to the latest version.
|
||||
1. Install web dependencies.
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm install
|
||||
cd ..
|
||||
```
|
||||
|
||||
1. Start backend (runs migrations first, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run flask db upgrade
|
||||
```
|
||||
|
||||
1. Start backend
|
||||
|
||||
```bash
|
||||
uv run flask run --host 0.0.0.0 --port=5001 --debug
|
||||
```
|
||||
|
||||
1. Start Dify [web](../web) service.
|
||||
1. Start Dify [web](../web) service (in a new terminal).
|
||||
|
||||
1. Setup your application by visiting `http://localhost:3000`.
|
||||
```bash
|
||||
cd web
|
||||
pnpm dev:inspect
|
||||
```
|
||||
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
```
|
||||
1. Optional: start the worker service (async tasks, in a new terminal).
|
||||
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
```
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery beat
|
||||
```
|
||||
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery beat
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Environment notes
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
|
||||
|
||||
- Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
bash for Linux
|
||||
|
||||
```bash
|
||||
sed -i "/^SECRET_KEY=/c\\SECRET_KEY=$(openssl rand -base64 42)" .env
|
||||
```
|
||||
|
||||
bash for Mac
|
||||
|
||||
```bash
|
||||
secret_key=$(openssl rand -base64 42)
|
||||
sed -i '' "/^SECRET_KEY=/c\\
|
||||
SECRET_KEY=${secret_key}" .env
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
1. Install dependencies for both the backend and the test environment
|
||||
|
||||
```bash
|
||||
uv sync --dev
|
||||
cd api
|
||||
uv sync --group dev
|
||||
```
|
||||
|
||||
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md)
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run pytest # Run all tests
|
||||
uv run pytest tests/unit_tests/ # Unit tests only
|
||||
uv run pytest tests/integration_tests/ # Integration tests
|
||||
|
||||
# Code quality
|
||||
../dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
```
|
||||
|
||||
@ -81,6 +81,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_fastopenapi,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
ext_import_modules,
|
||||
@ -128,6 +129,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_proxy_fix,
|
||||
ext_blueprints,
|
||||
ext_commands,
|
||||
ext_fastopenapi,
|
||||
ext_otel,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
|
||||
@ -82,13 +82,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
class DraftWorkflowNotExist(BaseHTTPException):
|
||||
error_code = "draft_workflow_not_exist"
|
||||
description = "Draft workflow need to be initialized."
|
||||
code = 400
|
||||
code = 404
|
||||
|
||||
|
||||
class DraftWorkflowNotSync(BaseHTTPException):
|
||||
error_code = "draft_workflow_not_sync"
|
||||
description = "Workflow graph might have been modified, please refresh and resubmit."
|
||||
code = 400
|
||||
code = 409
|
||||
|
||||
|
||||
class TracingConfigNotExist(BaseHTTPException):
|
||||
|
||||
@ -470,7 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
@ -508,7 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
@ -999,6 +999,7 @@ class DraftWorkflowTriggerRunApi(Resource):
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
workflow_args = dict(event.workflow_args)
|
||||
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
return helper.compact_generate_response(
|
||||
AppGenerateService.generate(
|
||||
@ -1147,6 +1148,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
|
||||
try:
|
||||
workflow_args = dict(trigger_debug_event.workflow_args)
|
||||
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
|
||||
@ -1,17 +1,17 @@
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from . import console_ns
|
||||
from controllers.fastopenapi import console_router
|
||||
|
||||
|
||||
@console_ns.route("/ping")
|
||||
class PingApi(Resource):
|
||||
@console_ns.doc("health_check")
|
||||
@console_ns.doc(description="Health check endpoint for connection testing")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
|
||||
)
|
||||
def get(self):
|
||||
"""Health check endpoint for connection testing"""
|
||||
return {"result": "pong"}
|
||||
class PingResponse(BaseModel):
|
||||
result: str = Field(description="Health check result", examples=["pong"])
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/ping",
|
||||
response_model=PingResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
def ping() -> PingResponse:
|
||||
"""Health check endpoint for connection testing."""
|
||||
return PingResponse(result="pong")
|
||||
|
||||
@ -1,20 +1,19 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.fastopenapi import console_router
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup, db
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
||||
from . import console_ns
|
||||
from .error import AlreadySetupError, NotInitValidateError
|
||||
from .init_validate import get_init_validate_status
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SetupRequestPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Admin email address")
|
||||
@ -28,78 +27,66 @@ class SetupRequestPayload(BaseModel):
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
SetupRequestPayload.__name__,
|
||||
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
class SetupStatusResponse(BaseModel):
|
||||
step: Literal["not_started", "finished"] = Field(description="Setup step status")
|
||||
setup_at: str | None = Field(default=None, description="Setup completion time (ISO format)")
|
||||
|
||||
|
||||
class SetupResponse(BaseModel):
|
||||
result: str = Field(description="Setup result", examples=["success"])
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/setup",
|
||||
response_model=SetupStatusResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
def get_setup_status_api() -> SetupStatusResponse:
|
||||
"""Get system setup status."""
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
setup_status = get_setup_status()
|
||||
if setup_status and not isinstance(setup_status, bool):
|
||||
return SetupStatusResponse(step="finished", setup_at=setup_status.setup_at.isoformat())
|
||||
if setup_status:
|
||||
return SetupStatusResponse(step="finished")
|
||||
return SetupStatusResponse(step="not_started")
|
||||
return SetupStatusResponse(step="finished")
|
||||
|
||||
|
||||
@console_ns.route("/setup")
|
||||
class SetupApi(Resource):
|
||||
@console_ns.doc("get_setup_status")
|
||||
@console_ns.doc(description="Get system setup status")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"SetupStatusResponse",
|
||||
{
|
||||
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
|
||||
"setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
|
||||
},
|
||||
),
|
||||
@console_router.post(
|
||||
"/setup",
|
||||
response_model=SetupResponse,
|
||||
tags=["console"],
|
||||
status_code=201,
|
||||
)
|
||||
@only_edition_self_hosted
|
||||
def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
"""Initialize system setup with admin account."""
|
||||
if get_setup_status():
|
||||
raise AlreadySetupError()
|
||||
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
normalized_email = payload.email.lower()
|
||||
|
||||
RegisterService.setup(
|
||||
email=normalized_email,
|
||||
name=payload.name,
|
||||
password=payload.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
language=payload.language,
|
||||
)
|
||||
def get(self):
|
||||
"""Get system setup status"""
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
setup_status = get_setup_status()
|
||||
# Check if setup_status is a DifySetup object rather than a bool
|
||||
if setup_status and not isinstance(setup_status, bool):
|
||||
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
|
||||
elif setup_status:
|
||||
return {"step": "finished"}
|
||||
return {"step": "not_started"}
|
||||
return {"step": "finished"}
|
||||
|
||||
@console_ns.doc("setup_system")
|
||||
@console_ns.doc(description="Initialize system setup with admin account")
|
||||
@console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
|
||||
@console_ns.response(
|
||||
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
|
||||
)
|
||||
@console_ns.response(400, "Already setup or validation failed")
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
"""Initialize system setup with admin account"""
|
||||
# is set up
|
||||
if get_setup_status():
|
||||
raise AlreadySetupError()
|
||||
|
||||
# is tenant created
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
args = SetupRequestPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=normalized_email,
|
||||
name=args.name,
|
||||
password=args.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
language=args.language,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 201
|
||||
return SetupResponse(result="success")
|
||||
|
||||
|
||||
def get_setup_status():
|
||||
def get_setup_status() -> DifySetup | bool | None:
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return db.session.query(DifySetup).first()
|
||||
else:
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
@ -1,15 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
from . import console_ns
|
||||
from controllers.fastopenapi import console_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -18,69 +14,61 @@ class VersionQuery(BaseModel):
|
||||
current_version: str = Field(..., description="Current application version")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
VersionQuery.__name__,
|
||||
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
class VersionFeatures(BaseModel):
|
||||
can_replace_logo: bool = Field(description="Whether logo replacement is supported")
|
||||
model_load_balancing_enabled: bool = Field(description="Whether model load balancing is enabled")
|
||||
|
||||
|
||||
class VersionResponse(BaseModel):
|
||||
version: str = Field(description="Latest version number")
|
||||
release_date: str = Field(description="Release date of latest version")
|
||||
release_notes: str = Field(description="Release notes for latest version")
|
||||
can_auto_update: bool = Field(description="Whether auto-update is supported")
|
||||
features: VersionFeatures = Field(description="Feature flags and capabilities")
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/version",
|
||||
response_model=VersionResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
def check_version_update(query: VersionQuery) -> VersionResponse:
|
||||
"""Check for application version updates."""
|
||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
|
||||
@console_ns.route("/version")
|
||||
class VersionApi(Resource):
|
||||
@console_ns.doc("check_version_update")
|
||||
@console_ns.doc(description="Check for application version updates")
|
||||
@console_ns.expect(console_ns.models[VersionQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"VersionResponse",
|
||||
{
|
||||
"version": fields.String(description="Latest version number"),
|
||||
"release_date": fields.String(description="Release date of latest version"),
|
||||
"release_notes": fields.String(description="Release notes for latest version"),
|
||||
"can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
|
||||
"features": fields.Raw(description="Feature flags and capabilities"),
|
||||
},
|
||||
result = VersionResponse(
|
||||
version=dify_config.project.version,
|
||||
release_date="",
|
||||
release_notes="",
|
||||
can_auto_update=False,
|
||||
features=VersionFeatures(
|
||||
can_replace_logo=dify_config.CAN_REPLACE_LOGO,
|
||||
model_load_balancing_enabled=dify_config.MODEL_LB_ENABLED,
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Check for application version updates"""
|
||||
args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
result = {
|
||||
"version": dify_config.project.version,
|
||||
"release_date": "",
|
||||
"release_notes": "",
|
||||
"can_auto_update": False,
|
||||
"features": {
|
||||
"can_replace_logo": dify_config.CAN_REPLACE_LOGO,
|
||||
"model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
|
||||
},
|
||||
}
|
||||
|
||||
if not check_update_url:
|
||||
return result
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": args.current_version},
|
||||
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args.current_version
|
||||
return result
|
||||
|
||||
content = json.loads(response.content)
|
||||
if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
|
||||
result["version"] = content["version"]
|
||||
result["release_date"] = content["releaseDate"]
|
||||
result["release_notes"] = content["releaseNotes"]
|
||||
result["can_auto_update"] = content["canAutoUpdate"]
|
||||
if not check_update_url:
|
||||
return result
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": query.current_version},
|
||||
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
||||
)
|
||||
content = response.json()
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result.version = query.current_version
|
||||
return result
|
||||
latest_version = content.get("version", result.version)
|
||||
if _has_new_version(latest_version=latest_version, current_version=f"{query.current_version}"):
|
||||
result.version = latest_version
|
||||
result.release_date = content.get("releaseDate", "")
|
||||
result.release_notes = content.get("releaseNotes", "")
|
||||
result.can_auto_update = content.get("canAutoUpdate", False)
|
||||
return result
|
||||
|
||||
|
||||
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
|
||||
try:
|
||||
|
||||
3
api/controllers/fastopenapi.py
Normal file
3
api/controllers/fastopenapi.py
Normal file
@ -0,0 +1,3 @@
|
||||
from fastopenapi.routers import FlaskRouter
|
||||
|
||||
console_router = FlaskRouter()
|
||||
@ -11,7 +11,9 @@ from controllers.service_api.wraps import DatasetApiResource, cloud_edition_bill
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
MetadataArgs,
|
||||
MetadataDetail,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
@ -22,7 +24,13 @@ class MetadataUpdatePayload(BaseModel):
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, MetadataUpdatePayload)
|
||||
register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
MetadataArgs,
|
||||
MetadataDetail,
|
||||
DocumentMetadataOperation,
|
||||
MetadataOperationData,
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from jwt import InvalidTokenError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
@ -18,7 +20,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import decode_jwt_token
|
||||
from libs.helper import email
|
||||
from libs.helper import EmailStr
|
||||
from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
@ -30,10 +32,35 @@ from services.app_service import AppService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
class EmailCodeLoginSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class EmailCodeLoginVerifyPayload(BaseModel):
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str = Field(min_length=1)
|
||||
|
||||
|
||||
register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload)
|
||||
|
||||
|
||||
@web_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
"""Resource for web app email/password login."""
|
||||
|
||||
@web_ns.expect(web_ns.models[LoginPayload.__name__])
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("web_app_login")
|
||||
@ -50,15 +77,10 @@ class LoginApi(Resource):
|
||||
@decrypt_password_field
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=valid_password, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = LoginPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
account = WebAppAuthService.authenticate(args["email"], args["password"])
|
||||
account = WebAppAuthService.authenticate(payload.email, payload.password)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
@ -145,6 +167,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("send_email_code_login")
|
||||
@web_ns.doc(description="Send email verification code for login")
|
||||
@web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code sent successfully",
|
||||
@ -153,19 +176,14 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if payload.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||
account = WebAppAuthService.get_user_through_email(payload.email)
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
@ -179,6 +197,7 @@ class EmailCodeLoginApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("verify_email_code_login")
|
||||
@web_ns.doc(description="Verify email code and complete login")
|
||||
@web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code verified and login successful",
|
||||
@ -189,17 +208,11 @@ class EmailCodeLoginApi(Resource):
|
||||
)
|
||||
@decrypt_code_field
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
user_email = args["email"].lower()
|
||||
user_email = payload.email.lower()
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
@ -210,10 +223,10 @@ class EmailCodeLoginApi(Resource):
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
if token_data["code"] != payload.code:
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
WebAppAuthService.revoke_email_code_login_token(payload.token)
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
if not account:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import reqparse
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
CompletionRequestError,
|
||||
@ -27,19 +29,22 @@ from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(web_ns, WorkflowRunPayload)
|
||||
|
||||
|
||||
@web_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(WebApiResource):
|
||||
@web_ns.doc("Run Workflow")
|
||||
@web_ns.doc(description="Execute a workflow with provided inputs and files.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
|
||||
"files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.expect(web_ns.models[WorkflowRunPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
@ -58,12 +63,8 @@ class WorkflowRunApi(WebApiResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -13,6 +15,9 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from controllers.console.app.workflow import LoopNodeRunPayload
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
@ -304,7 +309,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping,
|
||||
args: LoopNodeRunPayload,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
@ -320,7 +325,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
if args.inputs is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
@ -338,7 +343,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
agent=True,
|
||||
message_id=message.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from mimetypes import guess_extension
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
ModelConfigWithCredentialsEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentMessageEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageFileEvent,
|
||||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from models.model import App, AppMode, Message, MessageAnnotation
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
@ -203,6 +215,9 @@ class AppRunner:
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
agent: bool = False,
|
||||
message_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Handle invoke result
|
||||
@ -210,21 +225,41 @@ class AppRunner:
|
||||
:param queue_manager: application queue manager
|
||||
:param stream: stream
|
||||
:param agent: agent
|
||||
:param message_id: message id for multimodal output
|
||||
:param user_id: user id for multimodal output
|
||||
:param tenant_id: tenant id for multimodal output
|
||||
:return:
|
||||
"""
|
||||
if not stream and isinstance(invoke_result, LLMResult):
|
||||
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
||||
self._handle_invoke_result_direct(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
elif stream and isinstance(invoke_result, Generator):
|
||||
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
||||
self._handle_invoke_result_stream(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
agent=agent,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||
|
||||
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
|
||||
def _handle_invoke_result_direct(
|
||||
self,
|
||||
invoke_result: LLMResult,
|
||||
queue_manager: AppQueueManager,
|
||||
):
|
||||
"""
|
||||
Handle invoke result direct
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:param agent: agent
|
||||
:param message_id: message id for multimodal output
|
||||
:param user_id: user id for multimodal output
|
||||
:param tenant_id: tenant id for multimodal output
|
||||
:return:
|
||||
"""
|
||||
queue_manager.publish(
|
||||
@ -235,13 +270,22 @@ class AppRunner:
|
||||
)
|
||||
|
||||
def _handle_invoke_result_stream(
|
||||
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
|
||||
self,
|
||||
invoke_result: Generator[LLMResultChunk, None, None],
|
||||
queue_manager: AppQueueManager,
|
||||
agent: bool,
|
||||
message_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:param agent: agent
|
||||
:param message_id: message id for multimodal output
|
||||
:param user_id: user id for multimodal output
|
||||
:param tenant_id: tenant id for multimodal output
|
||||
:return:
|
||||
"""
|
||||
model: str = ""
|
||||
@ -259,12 +303,26 @@ class AppRunner:
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if not isinstance(content, str):
|
||||
# TODO(QuantumGhost): Add multimodal output support for easy ui.
|
||||
_logger.warning("received multimodal output, type=%s", type(content))
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
text += content # failback to str
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
@ -289,6 +347,101 @@ class AppRunner:
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_multimodal_image_content(
|
||||
self,
|
||||
content: ImagePromptMessageContent,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
queue_manager: AppQueueManager,
|
||||
):
|
||||
"""
|
||||
Handle multimodal image content from LLM response.
|
||||
Save the image and create a MessageFile record.
|
||||
|
||||
:param content: ImagePromptMessageContent instance
|
||||
:param message_id: message id
|
||||
:param user_id: user id
|
||||
:param tenant_id: tenant id
|
||||
:param queue_manager: queue manager
|
||||
:return:
|
||||
"""
|
||||
_logger.info("Handling multimodal image content for message %s", message_id)
|
||||
|
||||
image_url = content.url
|
||||
base64_data = content.base64_data
|
||||
|
||||
_logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data)
|
||||
|
||||
if not image_url and not base64_data:
|
||||
_logger.warning("Image content has neither URL nor base64 data")
|
||||
return
|
||||
|
||||
tool_file_manager = ToolFileManager()
|
||||
|
||||
# Save the image file
|
||||
try:
|
||||
if image_url:
|
||||
# Download image from URL
|
||||
_logger.info("Downloading image from URL: %s", image_url)
|
||||
tool_file = tool_file_manager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=image_url,
|
||||
conversation_id=None,
|
||||
)
|
||||
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
|
||||
elif base64_data:
|
||||
if base64_data.startswith("data:"):
|
||||
base64_data = base64_data.split(",", 1)[1]
|
||||
|
||||
image_binary = base64.b64decode(base64_data)
|
||||
mimetype = content.mime_type or "image/png"
|
||||
extension = guess_extension(mimetype) or ".png"
|
||||
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=image_binary,
|
||||
mimetype=mimetype,
|
||||
filename=f"generated_image{extension}",
|
||||
)
|
||||
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
|
||||
else:
|
||||
return
|
||||
except Exception:
|
||||
_logger.exception("Failed to save image file")
|
||||
return
|
||||
|
||||
# Create MessageFile record
|
||||
message_file = MessageFile(
|
||||
message_id=message_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
url=f"/files/tools/{tool_file.id}",
|
||||
upload_file_id=tool_file.id,
|
||||
created_by_role=(
|
||||
CreatorUserRole.ACCOUNT
|
||||
if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
|
||||
else CreatorUserRole.END_USER
|
||||
),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(message_file)
|
||||
|
||||
# Publish QueueMessageFileEvent
|
||||
queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file.id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
_logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)
|
||||
|
||||
def moderation_for_inputs(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
message_id=message.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
)
|
||||
|
||||
@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
message_id=message.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
)
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -40,6 +42,9 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from controllers.console.app.workflow import LoopNodeRunPayload
|
||||
|
||||
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -381,7 +386,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
args: LoopNodeRunPayload,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
@ -397,7 +402,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
if args.inputs is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
@ -413,7 +418,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs or {}),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
|
||||
@ -166,18 +166,22 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
# Determine which type of single node execution and get graph/variable_pool
|
||||
if single_iteration_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_inputs=dict(single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
elif single_loop_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=single_loop_run.node_id,
|
||||
user_inputs=dict(single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
|
||||
@ -314,44 +318,6 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_loop(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
|
||||
@ -39,6 +39,7 @@ from core.app.entities.task_entities import (
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamEvent,
|
||||
StreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||
_precomputed_event_type: StreamEvent | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
|
||||
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
|
||||
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
|
||||
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
|
||||
message_id=self._message_id
|
||||
)
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
event_type=event_type,
|
||||
event_type=self._precomputed_event_type,
|
||||
)
|
||||
else:
|
||||
yield self._agent_message_to_stream_response(
|
||||
|
||||
@ -5,7 +5,7 @@ from threading import Thread
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@ -30,6 +30,7 @@ from core.app.entities.task_entities import (
|
||||
StreamEvent,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_database import db
|
||||
@ -57,13 +58,15 @@ class MessageCycleManager:
|
||||
self._message_has_file: set[str] = set()
|
||||
|
||||
def get_message_event_type(self, message_id: str) -> StreamEvent:
|
||||
# Fast path: cached determination from prior QueueMessageFileEvent
|
||||
if message_id in self._message_has_file:
|
||||
return StreamEvent.MESSAGE_FILE
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
|
||||
# Use SQLAlchemy 2.x style session.scalar(select(...))
|
||||
with session_factory.create_session() as session:
|
||||
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
|
||||
|
||||
if has_file:
|
||||
if message_file:
|
||||
self._message_has_file.add(message_id)
|
||||
return StreamEvent.MESSAGE_FILE
|
||||
|
||||
@ -199,6 +202,8 @@ class MessageCycleManager:
|
||||
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
|
||||
|
||||
if message_file and message_file.url is not None:
|
||||
self._message_has_file.add(message_file.message_id)
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Sequence
|
||||
from collections.abc import Callable, Generator, Iterator, Sequence
|
||||
from typing import Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
@ -30,6 +30,142 @@ def _gen_tool_call_id() -> str:
|
||||
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
|
||||
|
||||
|
||||
def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None:
|
||||
if not callbacks:
|
||||
return
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
invoke(callback)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise
|
||||
logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e)
|
||||
|
||||
|
||||
def _get_or_create_tool_call(
|
||||
existing_tools_calls: list[AssistantPromptMessage.ToolCall],
|
||||
tool_call_id: str,
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Get or create a tool call by ID.
|
||||
|
||||
If `tool_call_id` is empty, returns the most recently created tool call.
|
||||
"""
|
||||
if not tool_call_id:
|
||||
if not existing_tools_calls:
|
||||
raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta")
|
||||
return existing_tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
existing_tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
|
||||
def _merge_tool_call_delta(
|
||||
tool_call: AssistantPromptMessage.ToolCall,
|
||||
delta: AssistantPromptMessage.ToolCall,
|
||||
) -> None:
|
||||
if delta.id:
|
||||
tool_call.id = delta.id
|
||||
if delta.type:
|
||||
tool_call.type = delta.type
|
||||
if delta.function.name:
|
||||
tool_call.function.name = delta.function.name
|
||||
if delta.function.arguments:
|
||||
tool_call.function.arguments += delta.function.arguments
|
||||
|
||||
|
||||
def _build_llm_result_from_first_chunk(
|
||||
model: str,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
chunks: Iterator[LLMResultChunk],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Build a single `LLMResult` from the first returned chunk.
|
||||
|
||||
This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
|
||||
"""
|
||||
content = ""
|
||||
content_list: list[PromptMessageContentUnionTypes] = []
|
||||
usage = LLMUsage.empty_usage()
|
||||
system_fingerprint: str | None = None
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
first_chunk = next(chunks, None)
|
||||
if first_chunk is not None:
|
||||
if isinstance(first_chunk.delta.message.content, str):
|
||||
content += first_chunk.delta.message.content
|
||||
elif isinstance(first_chunk.delta.message.content, list):
|
||||
content_list.extend(first_chunk.delta.message.content)
|
||||
|
||||
if first_chunk.delta.message.tool_calls:
|
||||
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
|
||||
|
||||
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
|
||||
system_fingerprint = first_chunk.system_fingerprint
|
||||
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=content or content_list,
|
||||
tool_calls=tools_calls,
|
||||
),
|
||||
usage=usage,
|
||||
system_fingerprint=system_fingerprint,
|
||||
)
|
||||
|
||||
|
||||
def _invoke_llm_via_plugin(
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
model_parameters: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_llm(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
model_parameters=model_parameters,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=tools,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_non_stream_plugin_result(
|
||||
model: str,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
result: Union[LLMResult, Iterator[LLMResultChunk]],
|
||||
) -> LLMResult:
|
||||
if isinstance(result, LLMResult):
|
||||
return result
|
||||
return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result)
|
||||
|
||||
|
||||
def _increase_tool_call(
|
||||
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
|
||||
):
|
||||
@ -40,42 +176,13 @@ def _increase_tool_call(
|
||||
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
|
||||
"""
|
||||
|
||||
def get_tool_call(tool_call_id: str):
|
||||
"""
|
||||
Get or create a tool call by ID
|
||||
|
||||
:param tool_call_id: tool call ID
|
||||
:return: existing or new tool call
|
||||
"""
|
||||
if not tool_call_id:
|
||||
return existing_tools_calls[-1]
|
||||
|
||||
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
|
||||
if _tool_call is None:
|
||||
_tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
existing_tools_calls.append(_tool_call)
|
||||
|
||||
return _tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# generate ID for tool calls with function name but no ID to track them
|
||||
if new_tool_call.function.name and not new_tool_call.id:
|
||||
new_tool_call.id = _gen_tool_call_id()
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.id)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id)
|
||||
_merge_tool_call_delta(tool_call, new_tool_call)
|
||||
|
||||
|
||||
class LargeLanguageModel(AIModel):
|
||||
@ -141,10 +248,7 @@ class LargeLanguageModel(AIModel):
|
||||
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
|
||||
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
result = plugin_model_manager.invoke_llm(
|
||||
result = _invoke_llm_via_plugin(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user or "unknown",
|
||||
plugin_id=self.plugin_id,
|
||||
@ -154,38 +258,13 @@ class LargeLanguageModel(AIModel):
|
||||
model_parameters=model_parameters,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stop=list(stop) if stop else None,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
content = ""
|
||||
content_list = []
|
||||
usage = LLMUsage.empty_usage()
|
||||
system_fingerprint = None
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
for chunk in result:
|
||||
if isinstance(chunk.delta.message.content, str):
|
||||
content += chunk.delta.message.content
|
||||
elif isinstance(chunk.delta.message.content, list):
|
||||
content_list.extend(chunk.delta.message.content)
|
||||
if chunk.delta.message.tool_calls:
|
||||
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
|
||||
|
||||
usage = chunk.delta.usage or LLMUsage.empty_usage()
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
break
|
||||
|
||||
result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=content or content_list,
|
||||
tool_calls=tools_calls,
|
||||
),
|
||||
usage=usage,
|
||||
system_fingerprint=system_fingerprint,
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
model=model, prompt_messages=prompt_messages, result=result
|
||||
)
|
||||
except Exception as e:
|
||||
self._trigger_invoke_error_callbacks(
|
||||
@ -425,27 +504,21 @@ class LargeLanguageModel(AIModel):
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_before_invoke(
|
||||
llm_instance=self,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(
|
||||
"Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e
|
||||
)
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_before_invoke",
|
||||
invoke=lambda callback: callback.on_before_invoke(
|
||||
llm_instance=self,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def _trigger_new_chunk_callbacks(
|
||||
self,
|
||||
@ -473,26 +546,22 @@ class LargeLanguageModel(AIModel):
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_new_chunk(
|
||||
llm_instance=self,
|
||||
chunk=chunk,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e)
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_new_chunk",
|
||||
invoke=lambda callback: callback.on_new_chunk(
|
||||
llm_instance=self,
|
||||
chunk=chunk,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def _trigger_after_invoke_callbacks(
|
||||
self,
|
||||
@ -521,28 +590,22 @@ class LargeLanguageModel(AIModel):
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_after_invoke(
|
||||
llm_instance=self,
|
||||
result=result,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(
|
||||
"Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e
|
||||
)
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_after_invoke",
|
||||
invoke=lambda callback: callback.on_after_invoke(
|
||||
llm_instance=self,
|
||||
result=result,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
def _trigger_invoke_error_callbacks(
|
||||
self,
|
||||
@ -571,25 +634,19 @@ class LargeLanguageModel(AIModel):
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
"""
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback.on_invoke_error(
|
||||
llm_instance=self,
|
||||
ex=ex,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
if callback.raise_error:
|
||||
raise e
|
||||
else:
|
||||
logger.warning(
|
||||
"Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e
|
||||
)
|
||||
_run_callbacks(
|
||||
callbacks,
|
||||
event="on_invoke_error",
|
||||
invoke=lambda callback: callback.on_invoke_error(
|
||||
llm_instance=self,
|
||||
ex=ex,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
),
|
||||
)
|
||||
|
||||
@ -154,7 +154,7 @@ class IrisConnectionPool:
|
||||
# Add to cache to skip future checks
|
||||
self._schemas_initialized.add(schema)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
logger.exception("Failed to ensure schema %s exists", schema)
|
||||
raise
|
||||
@ -177,6 +177,9 @@ class IrisConnectionPool:
|
||||
class IrisVector(BaseVector):
|
||||
"""IRIS vector database implementation using native VECTOR type and HNSW indexing."""
|
||||
|
||||
# Fallback score for full-text search when Rank function unavailable or TEXT_INDEX disabled
|
||||
_FULL_TEXT_FALLBACK_SCORE = 0.5
|
||||
|
||||
def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
|
||||
super().__init__(collection_name)
|
||||
self.config = config
|
||||
@ -272,41 +275,131 @@ class IrisVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search documents by full-text using iFind index or fallback to LIKE search."""
|
||||
"""Search documents by full-text using iFind index with BM25 relevance scoring.
|
||||
|
||||
When IRIS_TEXT_INDEX is enabled, this method uses the auto-generated Rank
|
||||
function from %iFind.Index.Basic to calculate BM25 relevance scores. The Rank
|
||||
function is automatically created with naming: {schema}.{table_name}_{index}Rank
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
**kwargs: Optional parameters including top_k, document_ids_filter
|
||||
|
||||
Returns:
|
||||
List of Document objects with relevance scores in metadata["score"]
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
with self._get_cursor() as cursor:
|
||||
if self.config.IRIS_TEXT_INDEX:
|
||||
# Use iFind full-text search with index
|
||||
# Use iFind full-text search with auto-generated Rank function
|
||||
text_index_name = f"idx_{self.table_name}_text"
|
||||
# IRIS removes underscores from function names
|
||||
table_no_underscore = self.table_name.replace("_", "")
|
||||
index_no_underscore = text_index_name.replace("_", "")
|
||||
rank_function = f"{self.schema}.{table_no_underscore}_{index_no_underscore}Rank"
|
||||
|
||||
# Build WHERE clause with document ID filter if provided
|
||||
where_clause = f"WHERE %ID %FIND search_index({text_index_name}, ?)"
|
||||
# First param for Rank function, second for FIND
|
||||
params = [query, query]
|
||||
|
||||
if document_ids_filter:
|
||||
# Add document ID filter
|
||||
placeholders = ",".join("?" * len(document_ids_filter))
|
||||
where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
|
||||
params.extend(document_ids_filter)
|
||||
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
SELECT TOP {top_k}
|
||||
id,
|
||||
text,
|
||||
meta,
|
||||
{rank_function}(%ID, ?) AS score
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE %ID %FIND search_index({text_index_name}, ?)
|
||||
{where_clause}
|
||||
ORDER BY score DESC
|
||||
"""
|
||||
cursor.execute(sql, (query,))
|
||||
|
||||
logger.debug(
|
||||
"iFind search: query='%s', index='%s', rank='%s'",
|
||||
query,
|
||||
text_index_name,
|
||||
rank_function,
|
||||
)
|
||||
|
||||
try:
|
||||
cursor.execute(sql, params)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
# Fallback to query without Rank function if it fails
|
||||
logger.warning(
|
||||
"Rank function '%s' failed, using fixed score",
|
||||
rank_function,
|
||||
exc_info=True,
|
||||
)
|
||||
sql_fallback = f"""
|
||||
SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
|
||||
FROM {self.schema}.{self.table_name}
|
||||
{where_clause}
|
||||
"""
|
||||
# Skip first param (for Rank function)
|
||||
cursor.execute(sql_fallback, params[1:])
|
||||
else:
|
||||
# Fallback to LIKE search (inefficient for large datasets)
|
||||
# Escape special characters for LIKE clause to prevent SQL injection
|
||||
from libs.helper import escape_like_pattern
|
||||
# Fallback to LIKE search (IRIS_TEXT_INDEX disabled)
|
||||
from libs.helper import ( # pylint: disable=import-outside-toplevel
|
||||
escape_like_pattern,
|
||||
)
|
||||
|
||||
escaped_query = escape_like_pattern(query)
|
||||
query_pattern = f"%{escaped_query}%"
|
||||
|
||||
# Build WHERE clause with document ID filter if provided
|
||||
where_clause = "WHERE text LIKE ? ESCAPE '\\\\'"
|
||||
params = [query_pattern]
|
||||
|
||||
if document_ids_filter:
|
||||
placeholders = ",".join("?" * len(document_ids_filter))
|
||||
where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
|
||||
params.extend(document_ids_filter)
|
||||
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE text LIKE ? ESCAPE '\\'
|
||||
{where_clause}
|
||||
ORDER BY LENGTH(text) ASC
|
||||
"""
|
||||
cursor.execute(sql, (query_pattern,))
|
||||
|
||||
logger.debug(
|
||||
"LIKE fallback (TEXT_INDEX disabled): query='%s'",
|
||||
query_pattern,
|
||||
)
|
||||
cursor.execute(sql, params)
|
||||
|
||||
docs = []
|
||||
for row in cursor.fetchall():
|
||||
if len(row) >= 3:
|
||||
metadata = json.loads(row[2]) if row[2] else {}
|
||||
docs.append(Document(page_content=row[1], metadata=metadata))
|
||||
# Expecting 4 columns: id, text, meta, score
|
||||
if len(row) >= 4:
|
||||
text_content = row[1]
|
||||
meta_str = row[2]
|
||||
score_value = row[3]
|
||||
|
||||
metadata = json.loads(meta_str) if meta_str else {}
|
||||
# Add score to metadata for hybrid search compatibility
|
||||
score = float(score_value) if score_value is not None else 0.0
|
||||
metadata["score"] = score
|
||||
|
||||
docs.append(Document(page_content=text_content, metadata=metadata))
|
||||
|
||||
logger.info(
|
||||
"Full-text search completed: query='%s', results=%d/%d",
|
||||
query,
|
||||
len(docs),
|
||||
top_k,
|
||||
)
|
||||
|
||||
if not docs:
|
||||
logger.info("Full-text search for '%s' returned no results", query)
|
||||
logger.warning("Full-text search for '%s' returned no results", query)
|
||||
|
||||
return docs
|
||||
|
||||
@ -370,7 +463,11 @@ class IrisVector(BaseVector):
|
||||
AS %iFind.Index.Basic
|
||||
(LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
|
||||
"""
|
||||
logger.info("Creating text index: %s with language: %s", text_index_name, language)
|
||||
logger.info(
|
||||
"Creating text index: %s with language: %s",
|
||||
text_index_name,
|
||||
language,
|
||||
)
|
||||
logger.info("SQL for text index: %s", sql_text_index)
|
||||
cursor.execute(sql_text_index)
|
||||
logger.info("Text index created successfully: %s", text_index_name)
|
||||
|
||||
@ -130,7 +130,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
class JsonMessage(BaseModel):
|
||||
json_object: dict
|
||||
json_object: dict | list
|
||||
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
|
||||
|
||||
class BlobMessage(BaseModel):
|
||||
@ -144,7 +144,14 @@ class ToolInvokeMessage(BaseModel):
|
||||
end: bool = Field(..., description="Whether the chunk is the last chunk")
|
||||
|
||||
class FileMessage(BaseModel):
|
||||
pass
|
||||
file_marker: str = Field(default="file_marker")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_file_message(cls, values):
|
||||
if isinstance(values, dict) and "file_marker" not in values:
|
||||
raise ValueError("Invalid FileMessage: missing file_marker")
|
||||
return values
|
||||
|
||||
class VariableMessage(BaseModel):
|
||||
variable_name: str = Field(..., description="The name of the variable")
|
||||
@ -234,10 +241,22 @@ class ToolInvokeMessage(BaseModel):
|
||||
|
||||
@field_validator("message", mode="before")
|
||||
@classmethod
|
||||
def decode_blob_message(cls, v):
|
||||
def decode_blob_message(cls, v, info: ValidationInfo):
|
||||
# 处理 blob 解码
|
||||
if isinstance(v, dict) and "blob" in v:
|
||||
with contextlib.suppress(Exception):
|
||||
v["blob"] = base64.b64decode(v["blob"])
|
||||
|
||||
# Force correct message type based on type field
|
||||
# Only wrap dict types to avoid wrapping already parsed Pydantic model objects
|
||||
if info.data and isinstance(info.data, dict) and isinstance(v, dict):
|
||||
msg_type = info.data.get("type")
|
||||
if msg_type == cls.MessageType.JSON:
|
||||
if "json_object" not in v:
|
||||
v = {"json_object": v}
|
||||
elif msg_type == cls.MessageType.FILE:
|
||||
v = {"file_marker": "file_marker"}
|
||||
|
||||
return v
|
||||
|
||||
@field_serializer("message")
|
||||
|
||||
@ -494,7 +494,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
@ -568,13 +568,18 @@ class AgentNode(Node[AgentNodeData]):
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
msg_metadata = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
@ -683,7 +688,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
yield agent_log
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
|
||||
@ -301,7 +301,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
json: list[dict | list] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
|
||||
@ -244,7 +244,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
json: list[dict | list] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
@ -400,7 +400,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
message.message.metadata = dict_metadata
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json:
|
||||
|
||||
21
api/enums/hosted_provider.py
Normal file
21
api/enums/hosted_provider.py
Normal file
@ -0,0 +1,21 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class HostedTrialProvider(StrEnum):
|
||||
"""
|
||||
Enum representing hosted model provider names for trial access.
|
||||
"""
|
||||
|
||||
OPENAI = "langgenius/openai/openai"
|
||||
ANTHROPIC = "langgenius/anthropic/anthropic"
|
||||
GEMINI = "langgenius/gemini/google"
|
||||
X = "langgenius/x/x"
|
||||
DEEPSEEK = "langgenius/deepseek/deepseek"
|
||||
TONGYI = "langgenius/tongyi/tongyi"
|
||||
|
||||
@property
|
||||
def config_key(self) -> str:
|
||||
"""Return the config key used in dify_config (e.g., HOSTED_{config_key}_PAID_ENABLED)."""
|
||||
if self == HostedTrialProvider.X:
|
||||
return "XAI"
|
||||
return self.name
|
||||
45
api/extensions/ext_fastopenapi.py
Normal file
45
api/extensions/ext_fastopenapi.py
Normal file
@ -0,0 +1,45 @@
|
||||
from fastopenapi.routers import FlaskRouter
|
||||
from flask_cors import CORS
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.fastopenapi import console_router
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
|
||||
|
||||
DOCS_PREFIX = "/fastopenapi"
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
docs_enabled = dify_config.SWAGGER_UI_ENABLED
|
||||
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
|
||||
redoc_url = f"{DOCS_PREFIX}/redoc" if docs_enabled else None
|
||||
openapi_url = f"{DOCS_PREFIX}/openapi.json" if docs_enabled else None
|
||||
|
||||
router = FlaskRouter(
|
||||
app=app,
|
||||
docs_url=docs_url,
|
||||
redoc_url=redoc_url,
|
||||
openapi_url=openapi_url,
|
||||
openapi_version="3.0.0",
|
||||
title="Dify API (FastOpenAPI PoC)",
|
||||
version="1.0",
|
||||
description="FastOpenAPI proof of concept for Dify API",
|
||||
)
|
||||
|
||||
# Ensure route decorators are evaluated.
|
||||
import controllers.console.ping as ping_module
|
||||
from controllers.console import setup
|
||||
|
||||
_ = ping_module
|
||||
_ = setup
|
||||
|
||||
router.include_router(console_router, prefix="/console/api")
|
||||
CORS(
|
||||
app,
|
||||
resources={r"/console/api/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=list(AUTHENTICATED_HEADERS),
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=list(EXPOSED_HEADERS),
|
||||
)
|
||||
app.extensions["fastopenapi"] = router
|
||||
@ -315,40 +315,48 @@ class App(Base):
|
||||
return None
|
||||
|
||||
|
||||
class AppModelConfig(Base):
|
||||
class AppModelConfig(TypeBase):
|
||||
__tablename__ = "app_model_configs"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
provider = mapped_column(String(255), nullable=True)
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
configs = mapped_column(sa.JSON, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
configs: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True, default=None)
|
||||
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
opening_statement = mapped_column(LongText)
|
||||
suggested_questions = mapped_column(LongText)
|
||||
suggested_questions_after_answer = mapped_column(LongText)
|
||||
speech_to_text = mapped_column(LongText)
|
||||
text_to_speech = mapped_column(LongText)
|
||||
more_like_this = mapped_column(LongText)
|
||||
model = mapped_column(LongText)
|
||||
user_input_form = mapped_column(LongText)
|
||||
dataset_query_variable = mapped_column(String(255))
|
||||
pre_prompt = mapped_column(LongText)
|
||||
agent_mode = mapped_column(LongText)
|
||||
sensitive_word_avoidance = mapped_column(LongText)
|
||||
retriever_resource = mapped_column(LongText)
|
||||
prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'"))
|
||||
chat_prompt_config = mapped_column(LongText)
|
||||
completion_prompt_config = mapped_column(LongText)
|
||||
dataset_configs = mapped_column(LongText)
|
||||
external_data_tools = mapped_column(LongText)
|
||||
file_upload = mapped_column(LongText)
|
||||
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
opening_statement: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
suggested_questions: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
suggested_questions_after_answer: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
speech_to_text: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
text_to_speech: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
more_like_this: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
model: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
user_input_form: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
dataset_query_variable: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||
pre_prompt: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
agent_mode: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
prompt_type: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, server_default=sa.text("'simple'"), default="simple"
|
||||
)
|
||||
chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
dataset_configs: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
external_data_tools: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
file_upload: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
@ -810,8 +818,8 @@ class Conversation(Base):
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
|
||||
if "model" in override_model_configs:
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
|
||||
# where is app_id?
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
|
||||
model_config = app_model_config.to_dict()
|
||||
else:
|
||||
model_config["configs"] = override_model_configs
|
||||
|
||||
@ -226,8 +226,7 @@ class Workflow(Base): # bug
|
||||
#
|
||||
# Currently, the following functions / methods would mutate the returned dict:
|
||||
#
|
||||
# - `_get_graph_and_variable_pool_of_single_iteration`.
|
||||
# - `_get_graph_and_variable_pool_of_single_loop`.
|
||||
# - `_get_graph_and_variable_pool_for_single_node_run`.
|
||||
return json.loads(self.graph) if self.graph else {}
|
||||
|
||||
def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]:
|
||||
|
||||
@ -31,7 +31,7 @@ dependencies = [
|
||||
"gunicorn~=23.0.0",
|
||||
"httpx[socks]~=0.27.0",
|
||||
"jieba==0.42.1",
|
||||
"json-repair>=0.41.1",
|
||||
"json-repair>=0.55.1",
|
||||
"jsonschema>=4.25.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.1.77",
|
||||
@ -93,6 +93,7 @@ dependencies = [
|
||||
"weaviate-client==4.17.0",
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"fastopenapi[flask]>=0.7.0",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
],
|
||||
"typeCheckingMode": "strict",
|
||||
"allowedUntypedLibraries": [
|
||||
"fastopenapi",
|
||||
"flask_restx",
|
||||
"flask_login",
|
||||
"opentelemetry.instrumentation.celery",
|
||||
|
||||
@ -521,12 +521,10 @@ class AppDslService:
|
||||
raise ValueError("Missing model_config for chat/agent-chat/completion app")
|
||||
# Initialize or update model config
|
||||
if not app.app_model_config:
|
||||
app_model_config = AppModelConfig().from_model_config_dict(model_config)
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
).from_model_config_dict(model_config)
|
||||
app_model_config.id = str(uuid4())
|
||||
app_model_config.app_id = app.id
|
||||
app_model_config.created_by = account.id
|
||||
app_model_config.updated_by = account.id
|
||||
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
self._session.add(app_model_config)
|
||||
@ -783,15 +781,16 @@ class AppDslService:
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
|
||||
def get_leaked_dependencies(
|
||||
cls, tenant_id: str, dsl_dependencies: list[PluginDependency]
|
||||
) -> list[PluginDependency]:
|
||||
"""
|
||||
Returns the leaked dependencies in current workspace
|
||||
"""
|
||||
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
|
||||
if not dependencies:
|
||||
if not dsl_dependencies:
|
||||
return []
|
||||
|
||||
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
|
||||
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dsl_dependencies)
|
||||
|
||||
@staticmethod
|
||||
def _generate_aes_key(tenant_id: str) -> bytes:
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
@ -18,6 +20,9 @@ from services.errors.app import QuotaExceededError, WorkflowIdFormatError, Workf
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from controllers.console.app.workflow import LoopNodeRunPayload
|
||||
|
||||
|
||||
class AppGenerateService:
|
||||
@classmethod
|
||||
@ -165,7 +170,9 @@ class AppGenerateService:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
|
||||
@classmethod
|
||||
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
def generate_single_loop(
|
||||
cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True
|
||||
):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
|
||||
@ -150,10 +150,9 @@ class AppService:
|
||||
db.session.flush()
|
||||
|
||||
if default_model_config:
|
||||
app_model_config = AppModelConfig(**default_model_config)
|
||||
app_model_config.app_id = app.id
|
||||
app_model_config.created_by = account.id
|
||||
app_model_config.updated_by = account.id
|
||||
app_model_config = AppModelConfig(
|
||||
**default_model_config, app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
)
|
||||
db.session.add(app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from enums.hosted_provider import HostedTrialProvider
|
||||
from services.billing_service import BillingService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
@ -170,6 +171,7 @@ class SystemFeatureModel(BaseModel):
|
||||
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
|
||||
enable_change_email: bool = True
|
||||
plugin_manager: PluginManagerModel = PluginManagerModel()
|
||||
trial_models: list[str] = []
|
||||
enable_trial_app: bool = False
|
||||
enable_explore_banner: bool = False
|
||||
|
||||
@ -227,9 +229,21 @@ class FeatureService:
|
||||
system_features.is_allow_register = dify_config.ALLOW_REGISTER
|
||||
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
|
||||
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
|
||||
system_features.trial_models = cls._fulfill_trial_models_from_env()
|
||||
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
|
||||
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
|
||||
|
||||
@classmethod
|
||||
def _fulfill_trial_models_from_env(cls) -> list[str]:
|
||||
return [
|
||||
provider.value
|
||||
for provider in HostedTrialProvider
|
||||
if (
|
||||
getattr(dify_config, f"HOSTED_{provider.config_key}_PAID_ENABLED", False)
|
||||
and getattr(dify_config, f"HOSTED_{provider.config_key}_TRIAL_ENABLED", False)
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_env(cls, features: FeatureModel):
|
||||
features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
|
||||
|
||||
@ -261,10 +261,9 @@ class MessageService:
|
||||
else:
|
||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
||||
app_model_config = AppModelConfig(
|
||||
id=conversation.app_model_config_id,
|
||||
app_id=app_model.id,
|
||||
)
|
||||
|
||||
app_model_config.id = conversation.app_model_config_id
|
||||
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
|
||||
if not app_model_config:
|
||||
raise ValueError("did not find app model config")
|
||||
|
||||
@ -870,15 +870,16 @@ class RagPipelineDslService:
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
|
||||
def get_leaked_dependencies(
|
||||
cls, tenant_id: str, dsl_dependencies: list[PluginDependency]
|
||||
) -> list[PluginDependency]:
|
||||
"""
|
||||
Returns the leaked dependencies in current workspace
|
||||
"""
|
||||
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
|
||||
if not dependencies:
|
||||
if not dsl_dependencies:
|
||||
return []
|
||||
|
||||
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
|
||||
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dsl_dependencies)
|
||||
|
||||
def _generate_aes_key(self, tenant_id: str) -> bytes:
|
||||
"""Generate AES key based on tenant_id"""
|
||||
|
||||
@ -44,7 +44,7 @@ class RagPipelineTransformService:
|
||||
doc_form = dataset.doc_form
|
||||
if not doc_form:
|
||||
return self._transform_to_empty_pipeline(dataset)
|
||||
retrieval_model = dataset.retrieval_model
|
||||
retrieval_model = RetrievalSetting.model_validate(dataset.retrieval_model) if dataset.retrieval_model else None
|
||||
pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique)
|
||||
# deal dependencies
|
||||
self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
|
||||
@ -154,7 +154,12 @@ class RagPipelineTransformService:
|
||||
return node
|
||||
|
||||
def _deal_knowledge_index(
|
||||
self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
|
||||
self,
|
||||
dataset: Dataset,
|
||||
doc_form: str,
|
||||
indexing_technique: str | None,
|
||||
retrieval_model: RetrievalSetting | None,
|
||||
node: dict,
|
||||
):
|
||||
knowledge_configuration_dict = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
|
||||
@ -163,10 +168,9 @@ class RagPipelineTransformService:
|
||||
knowledge_configuration.embedding_model = dataset.embedding_model
|
||||
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
|
||||
if retrieval_model:
|
||||
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
|
||||
if indexing_technique == "economy":
|
||||
retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
knowledge_configuration.retrieval_model = retrieval_setting
|
||||
retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
knowledge_configuration.retrieval_model = retrieval_model
|
||||
else:
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
|
||||
CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:"
|
||||
CACHE_REDIS_TTL = 60 * 15 # 15 minutes
|
||||
CACHE_REDIS_TTL = 60 * 60 # 1 hour
|
||||
|
||||
|
||||
def _get_redis_cache_key(plugin_id: str) -> str:
|
||||
|
||||
@ -172,7 +172,6 @@ class TestAgentService:
|
||||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
id=fake.uuid4(),
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
@ -180,6 +179,7 @@ class TestAgentService:
|
||||
model="gpt-3.5-turbo",
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
@ -413,7 +413,6 @@ class TestAgentService:
|
||||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
id=fake.uuid4(),
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
@ -421,6 +420,7 @@ class TestAgentService:
|
||||
model="gpt-3.5-turbo",
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
@ -485,7 +485,6 @@ class TestAgentService:
|
||||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
id=fake.uuid4(),
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
@ -493,6 +492,7 @@ class TestAgentService:
|
||||
model="gpt-3.5-turbo",
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -226,26 +226,27 @@ class TestAppDslService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create model config for the app
|
||||
model_config = AppModelConfig()
|
||||
model_config.id = fake.uuid4()
|
||||
model_config.app_id = app.id
|
||||
model_config.provider = "openai"
|
||||
model_config.model_id = "gpt-3.5-turbo"
|
||||
model_config.model = json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}
|
||||
model_config = AppModelConfig(
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
model=json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}
|
||||
),
|
||||
pre_prompt="You are a helpful assistant.",
|
||||
prompt_type="simple",
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
model_config.pre_prompt = "You are a helpful assistant."
|
||||
model_config.prompt_type = "simple"
|
||||
model_config.created_by = account.id
|
||||
model_config.updated_by = account.id
|
||||
model_config.id = fake.uuid4()
|
||||
|
||||
# Set the app_model_config_id to link the config
|
||||
app.app_model_config_id = model_config.id
|
||||
|
||||
@ -925,24 +925,24 @@ class TestWorkflowService:
|
||||
# Create app model config (required for conversion)
|
||||
from models.model import AppModelConfig
|
||||
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config.id = fake.uuid4()
|
||||
app_model_config.app_id = app.id
|
||||
app_model_config.tenant_id = app.tenant_id
|
||||
app_model_config.provider = "openai"
|
||||
app_model_config.model_id = "gpt-3.5-turbo"
|
||||
# Set the model field directly - this is what model_dict property returns
|
||||
app_model_config.model = json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"completion_params": {"max_tokens": 1000, "temperature": 0.7},
|
||||
}
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
# Set the model field directly - this is what model_dict property returns
|
||||
model=json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"completion_params": {"max_tokens": 1000, "temperature": 0.7},
|
||||
}
|
||||
),
|
||||
# Set pre_prompt for PromptTemplateConfigManager
|
||||
pre_prompt="You are a helpful assistant.",
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
# Set pre_prompt for PromptTemplateConfigManager
|
||||
app_model_config.pre_prompt = "You are a helpful assistant."
|
||||
app_model_config.created_by = account.id
|
||||
app_model_config.updated_by = account.id
|
||||
app_model_config.id = fake.uuid4()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -987,24 +987,24 @@ class TestWorkflowService:
|
||||
# Create app model config (required for conversion)
|
||||
from models.model import AppModelConfig
|
||||
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config.id = fake.uuid4()
|
||||
app_model_config.app_id = app.id
|
||||
app_model_config.tenant_id = app.tenant_id
|
||||
app_model_config.provider = "openai"
|
||||
app_model_config.model_id = "gpt-3.5-turbo"
|
||||
# Set the model field directly - this is what model_dict property returns
|
||||
app_model_config.model = json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"completion_params": {"max_tokens": 1000, "temperature": 0.7},
|
||||
}
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
# Set the model field directly - this is what model_dict property returns
|
||||
model=json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"completion_params": {"max_tokens": 1000, "temperature": 0.7},
|
||||
}
|
||||
),
|
||||
# Set pre_prompt for PromptTemplateConfigManager
|
||||
pre_prompt="Complete the following text:",
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
# Set pre_prompt for PromptTemplateConfigManager
|
||||
app_model_config.pre_prompt = "Complete the following text:"
|
||||
app_model_config.created_by = account.id
|
||||
app_model_config.updated_by = account.id
|
||||
app_model_config.id = fake.uuid4()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
@ -0,0 +1,27 @@
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
def test_console_ping_fastopenapi_returns_pong(app: Flask):
|
||||
ext_fastopenapi.init_app(app)
|
||||
|
||||
client = app.test_client()
|
||||
response = client.get("/console/api/ping")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "pong"}
|
||||
@ -0,0 +1,56 @@
|
||||
import builtins
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
def test_console_setup_fastopenapi_get_not_started(app: Flask):
|
||||
ext_fastopenapi.init_app(app)
|
||||
|
||||
with (
|
||||
patch("controllers.console.setup.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.setup.get_setup_status", return_value=None),
|
||||
):
|
||||
client = app.test_client()
|
||||
response = client.get("/console/api/setup")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"step": "not_started", "setup_at": None}
|
||||
|
||||
|
||||
def test_console_setup_fastopenapi_post_success(app: Flask):
|
||||
ext_fastopenapi.init_app(app)
|
||||
|
||||
payload = {
|
||||
"email": "admin@example.com",
|
||||
"name": "Admin",
|
||||
"password": "Passw0rd1",
|
||||
"language": "en-US",
|
||||
}
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.setup.get_setup_status", return_value=None),
|
||||
patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0),
|
||||
patch("controllers.console.setup.get_init_validate_status", return_value=True),
|
||||
patch("controllers.console.setup.RegisterService.setup"),
|
||||
):
|
||||
client = app.test_client()
|
||||
response = client.post("/console/api/setup", json=payload)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.get_json() == {"result": "success"}
|
||||
@ -0,0 +1,35 @@
|
||||
import builtins
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from configs import dify_config
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
def test_console_version_fastopenapi_returns_current_version(app: Flask):
|
||||
ext_fastopenapi.init_app(app)
|
||||
|
||||
with patch("controllers.console.version.dify_config.CHECK_UPDATE_URL", None):
|
||||
client = app.test_client()
|
||||
response = client.get("/console/api/version", query_string={"current_version": "0.0.0"})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data["version"] == dify_config.project.version
|
||||
assert data["release_date"] == ""
|
||||
assert data["release_notes"] == ""
|
||||
assert data["can_auto_update"] is False
|
||||
assert "features" in data
|
||||
@ -1,39 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.console.setup import SetupApi
|
||||
|
||||
|
||||
class TestSetupApi:
|
||||
def test_post_lowercases_email_before_register(self):
|
||||
"""Ensure setup registration normalizes email casing."""
|
||||
payload = {
|
||||
"email": "Admin@Example.com",
|
||||
"name": "Admin User",
|
||||
"password": "ValidPass123!",
|
||||
"language": "en-US",
|
||||
}
|
||||
setup_api = SetupApi(api=None)
|
||||
|
||||
mock_console_ns = SimpleNamespace(payload=payload)
|
||||
|
||||
with (
|
||||
patch("controllers.console.setup.console_ns", mock_console_ns),
|
||||
patch("controllers.console.setup.get_setup_status", return_value=False),
|
||||
patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0),
|
||||
patch("controllers.console.setup.get_init_validate_status", return_value=True),
|
||||
patch("controllers.console.setup.extract_remote_ip", return_value="127.0.0.1"),
|
||||
patch("controllers.console.setup.request", object()),
|
||||
patch("controllers.console.setup.RegisterService.setup") as mock_register,
|
||||
):
|
||||
response, status = setup_api.post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
assert status == 201
|
||||
mock_register.assert_called_once_with(
|
||||
email="admin@example.com",
|
||||
name=payload["name"],
|
||||
password=payload["password"],
|
||||
ip_address="127.0.0.1",
|
||||
language=payload["language"],
|
||||
)
|
||||
@ -0,0 +1,454 @@
|
||||
"""Test multimodal image output handling in BaseAppRunner."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueMessageFileEvent
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
class TestBaseAppRunnerMultimodal:
|
||||
"""Test that BaseAppRunner correctly handles multimodal image content."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id(self):
|
||||
"""Mock user ID."""
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tenant_id(self):
|
||||
"""Mock tenant ID."""
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_id(self):
|
||||
"""Mock message ID."""
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
"""Create a mock queue manager."""
|
||||
manager = MagicMock()
|
||||
manager.invoke_from = InvokeFrom.SERVICE_API
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_file(self):
|
||||
"""Create a mock tool file."""
|
||||
tool_file = MagicMock()
|
||||
tool_file.id = str(uuid4())
|
||||
return tool_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file(self):
|
||||
"""Create a mock message file."""
|
||||
message_file = MagicMock()
|
||||
message_file.id = str(uuid4())
|
||||
return message_file
|
||||
|
||||
def test_handle_multimodal_image_content_with_url(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test handling image from URL."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify tool file was created from URL
|
||||
mock_mgr.create_file_by_url.assert_called_once_with(
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
file_url=image_url,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
# Verify message file was created with correct parameters
|
||||
mock_msg_file_class.assert_called_once()
|
||||
call_kwargs = mock_msg_file_class.call_args[1]
|
||||
assert call_kwargs["message_id"] == mock_message_id
|
||||
assert call_kwargs["type"] == FileType.IMAGE
|
||||
assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE
|
||||
assert call_kwargs["belongs_to"] == "assistant"
|
||||
assert call_kwargs["created_by"] == mock_user_id
|
||||
|
||||
# Verify database operations
|
||||
mock_session.add.assert_called_once_with(mock_message_file)
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once_with(mock_message_file)
|
||||
|
||||
# Verify event was published
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
publish_call = mock_queue_manager.publish.call_args
|
||||
assert isinstance(publish_call[0][0], QueueMessageFileEvent)
|
||||
assert publish_call[0][0].message_file_id == mock_message_file.id
|
||||
# publish_from might be passed as positional or keyword argument
|
||||
assert (
|
||||
publish_call[0][1] == PublishFrom.APPLICATION_MANAGER
|
||||
or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def test_handle_multimodal_image_content_with_base64(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test handling image from base64 data."""
|
||||
# Arrange
|
||||
import base64
|
||||
|
||||
# Create a small test image (1x1 PNG)
|
||||
test_image_data = base64.b64encode(
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde"
|
||||
).decode()
|
||||
content = ImagePromptMessageContent(
|
||||
base64_data=test_image_data,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify tool file was created from base64
|
||||
mock_mgr.create_file_by_raw.assert_called_once()
|
||||
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
|
||||
assert call_kwargs["user_id"] == mock_user_id
|
||||
assert call_kwargs["tenant_id"] == mock_tenant_id
|
||||
assert call_kwargs["conversation_id"] is None
|
||||
assert "file_binary" in call_kwargs
|
||||
assert call_kwargs["mimetype"] == "image/png"
|
||||
assert call_kwargs["filename"].startswith("generated_image")
|
||||
assert call_kwargs["filename"].endswith(".png")
|
||||
|
||||
# Verify message file was created
|
||||
mock_msg_file_class.assert_called_once()
|
||||
|
||||
# Verify database operations
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
# Verify event was published
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
|
||||
def test_handle_multimodal_image_content_with_base64_data_uri(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test handling image from base64 data with URI prefix."""
|
||||
# Arrange
|
||||
# Data URI format: data:image/png;base64,<base64_data>
|
||||
test_image_data = (
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
)
|
||||
content = ImagePromptMessageContent(
|
||||
base64_data=f"data:image/png;base64,{test_image_data}",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - verify that base64 data was extracted correctly (without prefix)
|
||||
mock_mgr.create_file_by_raw.assert_called_once()
|
||||
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
|
||||
# The base64 data should be decoded, so we check the binary was passed
|
||||
assert "file_binary" in call_kwargs
|
||||
|
||||
def test_handle_multimodal_image_content_without_url_or_base64(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
):
|
||||
"""Test handling image content without URL or base64 data."""
|
||||
# Arrange
|
||||
content = ImagePromptMessageContent(
|
||||
url="",
|
||||
base64_data="",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - should not create any files or publish events
|
||||
mock_mgr_class.assert_not_called()
|
||||
mock_msg_file_class.assert_not_called()
|
||||
mock_session.add.assert_not_called()
|
||||
mock_queue_manager.publish.assert_not_called()
|
||||
|
||||
def test_handle_multimodal_image_content_with_error(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
):
|
||||
"""Test handling image content when an error occurs."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock to raise exception
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.side_effect = Exception("Network error")
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
# Should not raise exception, just log it
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - should not create message file or publish event on error
|
||||
mock_msg_file_class.assert_not_called()
|
||||
mock_session.add.assert_not_called()
|
||||
mock_queue_manager.publish.assert_not_called()
|
||||
|
||||
def test_handle_multimodal_image_content_debugger_mode(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test that debugger mode sets correct created_by_role."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - verify created_by_role is ACCOUNT for debugger mode
|
||||
call_kwargs = mock_msg_file_class.call_args[1]
|
||||
assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT
|
||||
|
||||
def test_handle_multimodal_image_content_service_api_mode(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test that service API mode sets correct created_by_role."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - verify created_by_role is END_USER for service API
|
||||
call_kwargs = mock_msg_file_class.call_args[1]
|
||||
assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER
|
||||
@ -1,7 +1,6 @@
|
||||
"""Unit tests for the message cycle manager optimization."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization:
|
||||
|
||||
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization:
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE when message has no files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Setup mock session and no message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization:
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
|
||||
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute: compute event type once, then pass to message_to_stream_response
|
||||
with current_app.app_context():
|
||||
@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization:
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
|
||||
"""Test that message_to_stream_response skips database query when event_type is provided."""
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Execute with event_type provided
|
||||
result = message_cycle_manager.message_to_stream_response(
|
||||
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization:
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE
|
||||
# Should not query database when event_type is provided
|
||||
mock_session_class.assert_not_called()
|
||||
# Should not open a session when event_type is provided
|
||||
mock_session_factory.create_session.assert_not_called()
|
||||
|
||||
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
|
||||
"""Test message_to_stream_response with from_variable_selector parameter."""
|
||||
@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization:
|
||||
def test_optimization_usage_example(self, message_cycle_manager):
|
||||
"""Test the optimization pattern that should be used by callers."""
|
||||
# Step 1: Get event type once (this queries database)
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None # No files
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = None # No files
|
||||
with current_app.app_context():
|
||||
event_type = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
|
||||
# Should query database once
|
||||
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
|
||||
# Should open session once
|
||||
mock_session_factory.create_session.assert_called_once()
|
||||
assert event_type == StreamEvent.MESSAGE
|
||||
|
||||
# Step 2: Use event_type for multiple calls (no additional queries)
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
mock_session_class.return_value.__enter__.return_value = Mock()
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = Mock()
|
||||
|
||||
chunk1_response = message_cycle_manager.message_to_stream_response(
|
||||
answer="Chunk 1", message_id="test-message-id", event_type=event_type
|
||||
@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization:
|
||||
answer="Chunk 2", message_id="test-message-id", event_type=event_type
|
||||
)
|
||||
|
||||
# Should not query database again
|
||||
mock_session_class.assert_not_called()
|
||||
# Should not open session again when event_type provided
|
||||
mock_session_factory.create_session.assert_not_called()
|
||||
|
||||
assert chunk1_response.event == StreamEvent.MESSAGE
|
||||
assert chunk2_response.event == StreamEvent.MESSAGE
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
|
||||
|
||||
@ -97,3 +99,14 @@ def test__increase_tool_call():
|
||||
mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
|
||||
with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
|
||||
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
|
||||
|
||||
|
||||
def test__increase_tool_call__no_id_no_name_first_delta_should_raise():
|
||||
inputs = [
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')),
|
||||
]
|
||||
actual: list[ToolCall] = []
|
||||
with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()):
|
||||
with pytest.raises(ValueError):
|
||||
_increase_tool_call(inputs, actual)
|
||||
|
||||
@ -0,0 +1,103 @@
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
*,
|
||||
model: str = "test-model",
|
||||
content: str | list[TextPromptMessageContent] | None,
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
system_fingerprint: str | None = None,
|
||||
) -> LLMResultChunk:
|
||||
message = AssistantPromptMessage(content=content, tool_calls=tool_calls or [])
|
||||
delta = LLMResultChunkDelta(index=0, message=message, usage=usage)
|
||||
return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint)
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_tool_calls():
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
|
||||
tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id="1",
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments=""),
|
||||
),
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id="",
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='{"arg1": '),
|
||||
),
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id="",
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'),
|
||||
),
|
||||
]
|
||||
|
||||
usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1})
|
||||
chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1")
|
||||
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
model="test-model", prompt_messages=prompt_messages, result=iter([chunk])
|
||||
)
|
||||
|
||||
assert result.model == "test-model"
|
||||
assert result.prompt_messages == prompt_messages
|
||||
assert result.message.content == "hello"
|
||||
assert result.usage.prompt_tokens == 1
|
||||
assert result.system_fingerprint == "fp-1"
|
||||
assert result.message.tool_calls == [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id="1",
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__from_first_chunk_list_content():
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
|
||||
content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")]
|
||||
chunk = _make_chunk(content=content_list, usage=LLMUsage.empty_usage())
|
||||
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
model="test-model", prompt_messages=prompt_messages, result=iter([chunk])
|
||||
)
|
||||
|
||||
assert result.message.content == content_list
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__passthrough_llm_result():
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
llm_result = LLMResult(
|
||||
model="test-model",
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content="ok"),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
|
||||
assert (
|
||||
_normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=llm_result)
|
||||
== llm_result
|
||||
)
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__empty_iterator_defaults():
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
|
||||
result = _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=iter([]))
|
||||
|
||||
assert result.model == "test-model"
|
||||
assert result.prompt_messages == prompt_messages
|
||||
assert result.message.content == []
|
||||
assert result.message.tool_calls == []
|
||||
assert result.usage == LLMUsage.empty_usage()
|
||||
assert result.system_fingerprint is None
|
||||
@ -475,3 +475,130 @@ def test_valid_api_key_works():
|
||||
headers = executor._assembling_headers()
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"] == "Bearer valid-api-key-123"
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_unquoted_uuid_variable():
|
||||
"""Test that unquoted UUID variables are correctly handled in JSON body.
|
||||
|
||||
This test verifies the fix for issue #31436 where json_repair would truncate
|
||||
certain UUID patterns (like 57eeeeb1-...) when they appeared as unquoted values.
|
||||
"""
|
||||
# UUID that triggers the json_repair truncation bug
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "uuid"], test_uuid)
|
||||
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Unquoted UUID Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
# UUID variable without quotes - this is the problematic case
|
||||
value='{"rowId": {{#pre_node_id.uuid#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# The UUID should be preserved in full, not truncated
|
||||
assert executor.json == {"rowId": test_uuid}
|
||||
assert len(executor.json["rowId"]) == len(test_uuid)
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
|
||||
"""Test that unquoted UUID variables with newlines in JSON are handled correctly.
|
||||
|
||||
This is a specific case from issue #31436 where the JSON body contains newlines.
|
||||
"""
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "uuid"], test_uuid)
|
||||
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Unquoted UUID and Newlines",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
# JSON with newlines and unquoted UUID variable
|
||||
value='{\n"rowId": {{#pre_node_id.uuid#}}\n}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# The UUID should be preserved in full
|
||||
assert executor.json == {"rowId": test_uuid}
|
||||
|
||||
|
||||
def test_executor_with_json_body_preserves_numbers_and_strings():
|
||||
"""Test that numbers are preserved and string values are properly quoted."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["node", "count"], 42)
|
||||
variable_pool.add(["node", "id"], "abc-123")
|
||||
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with mixed types",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"count": {{#node.count#}}, "id": {{#node.id#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
assert executor.json["count"] == 42
|
||||
assert executor.json["id"] == "abc-123"
|
||||
|
||||
4671
api/uv.lock
generated
4671
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
28
dev/setup
Executable file
28
dev/setup
Executable file
@ -0,0 +1,28 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
ROOT="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
API_ENV_EXAMPLE="$ROOT/api/.env.example"
|
||||
API_ENV="$ROOT/api/.env"
|
||||
WEB_ENV_EXAMPLE="$ROOT/web/.env.example"
|
||||
WEB_ENV="$ROOT/web/.env.local"
|
||||
MIDDLEWARE_ENV_EXAMPLE="$ROOT/docker/middleware.env.example"
|
||||
MIDDLEWARE_ENV="$ROOT/docker/middleware.env"
|
||||
|
||||
# 1) Copy api/.env.example -> api/.env
|
||||
cp "$API_ENV_EXAMPLE" "$API_ENV"
|
||||
|
||||
# 2) Copy web/.env.example -> web/.env.local
|
||||
cp "$WEB_ENV_EXAMPLE" "$WEB_ENV"
|
||||
|
||||
# 3) Copy docker/middleware.env.example -> docker/middleware.env
|
||||
cp "$MIDDLEWARE_ENV_EXAMPLE" "$MIDDLEWARE_ENV"
|
||||
|
||||
# 4) Install deps
|
||||
cd "$ROOT/api"
|
||||
uv sync --group dev
|
||||
|
||||
cd "$ROOT/web"
|
||||
pnpm install
|
||||
@ -3,8 +3,9 @@
|
||||
set -x
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/.."
|
||||
cd "$SCRIPT_DIR/../api"
|
||||
|
||||
uv run flask db upgrade
|
||||
|
||||
uv --directory api run \
|
||||
uv run \
|
||||
flask run --host 0.0.0.0 --port=5001 --debug
|
||||
|
||||
8
dev/start-docker-compose
Executable file
8
dev/start-docker-compose
Executable file
@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
ROOT="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
cd "$ROOT/docker"
|
||||
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
@ -83,7 +83,7 @@ while [[ $# -gt 0 ]]; do
|
||||
done
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/.."
|
||||
cd "$SCRIPT_DIR/../api"
|
||||
|
||||
if [[ -n "${ENV_FILE}" ]]; then
|
||||
if [[ ! -f "${ENV_FILE}" ]]; then
|
||||
@ -123,6 +123,6 @@ echo " Concurrency: ${CONCURRENCY}"
|
||||
echo " Pool: ${POOL}"
|
||||
echo " Log Level: ${LOGLEVEL}"
|
||||
|
||||
uv --directory api run \
|
||||
uv run \
|
||||
celery -A app.celery worker \
|
||||
-P ${POOL} -c ${CONCURRENCY} --loglevel ${LOGLEVEL} -Q ${QUEUES}
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
REPO_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
# rely on `poetry` in path
|
||||
|
||||
@ -1,27 +1,15 @@
|
||||
import type { StorybookConfig } from '@storybook/nextjs'
|
||||
import path from 'node:path'
|
||||
import { fileURLToPath } from 'node:url'
|
||||
|
||||
const storybookDir = path.dirname(fileURLToPath(import.meta.url))
|
||||
import type { StorybookConfig } from '@storybook/nextjs-vite'
|
||||
|
||||
const config: StorybookConfig = {
|
||||
stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'],
|
||||
addons: [
|
||||
'@storybook/addon-onboarding',
|
||||
// Not working with Storybook Vite framework
|
||||
// '@storybook/addon-onboarding',
|
||||
'@storybook/addon-links',
|
||||
'@storybook/addon-docs',
|
||||
'@chromatic-com/storybook',
|
||||
],
|
||||
framework: {
|
||||
name: '@storybook/nextjs',
|
||||
options: {
|
||||
builder: {
|
||||
useSWC: true,
|
||||
lazyCompilation: false,
|
||||
},
|
||||
nextConfigPath: undefined,
|
||||
},
|
||||
},
|
||||
framework: '@storybook/nextjs-vite',
|
||||
staticDirs: ['../public'],
|
||||
core: {
|
||||
disableWhatsNewNotifications: true,
|
||||
@ -29,17 +17,5 @@ const config: StorybookConfig = {
|
||||
docs: {
|
||||
defaultName: 'Documentation',
|
||||
},
|
||||
webpackFinal: async (config) => {
|
||||
// Add alias to mock problematic modules with circular dependencies
|
||||
config.resolve = config.resolve || {}
|
||||
config.resolve.alias = {
|
||||
...config.resolve.alias,
|
||||
// Mock the plugin index files to avoid circular dependencies
|
||||
[path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/context-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/context-block.tsx'),
|
||||
[path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/history-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/history-block.tsx'),
|
||||
[path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/query-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/query-block.tsx'),
|
||||
}
|
||||
return config
|
||||
},
|
||||
}
|
||||
export default config
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
/**
|
||||
* @vitest-environment jsdom
|
||||
*/
|
||||
import type { ReactNode } from 'react'
|
||||
import type { ModalContextState } from '@/context/modal-context'
|
||||
import type { ProviderContextState } from '@/context/provider-context'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { RiAddLine, RiDeleteBinLine, RiEditLine, RiMore2Fill, RiSaveLine, RiShareLine } from '@remixicon/react'
|
||||
import ActionButton, { ActionButtonState } from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { IChatItem } from '@/app/components/base/chat/chat/type'
|
||||
import type { AgentLogDetailResponse } from '@/models/log'
|
||||
import { useEffect, useRef } from 'react'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { ReactNode } from 'react'
|
||||
import AnswerIcon from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { AppIconSelection } from '.'
|
||||
import { useState } from 'react'
|
||||
import AppIconPicker from '.'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { ComponentProps } from 'react'
|
||||
import AppIcon from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { ComponentProps } from 'react'
|
||||
import { useEffect } from 'react'
|
||||
import AudioBtn from '.'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import AudioGallery from '.'
|
||||
|
||||
const AUDIO_SOURCES = [
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import AutoHeightTextarea from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import Avatar from '.'
|
||||
|
||||
const meta = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import Badge from '../badge'
|
||||
|
||||
const meta = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import BlockInput from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import AddButton from './add-button'
|
||||
|
||||
const meta = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
|
||||
import { RocketLaunchIcon } from '@heroicons/react/20/solid'
|
||||
import { Button } from '.'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import SyncButton from './sync-button'
|
||||
|
||||
const meta = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { ChatItem } from '../../types'
|
||||
import { WorkflowRunningStatus } from '@/app/components/workflow/types'
|
||||
import Answer from '.'
|
||||
|
||||
178
web/app/components/base/chat/chat/hooks.multimodal.spec.ts
Normal file
178
web/app/components/base/chat/chat/hooks.multimodal.spec.ts
Normal file
@ -0,0 +1,178 @@
|
||||
/**
|
||||
* Tests for multimodal image file handling in chat hooks.
|
||||
* Tests the file object conversion logic without full hook integration.
|
||||
*/
|
||||
|
||||
describe('Multimodal File Handling', () => {
|
||||
describe('File type to MIME type mapping', () => {
|
||||
it('should map image to image/png', () => {
|
||||
const fileType: string = 'image'
|
||||
const expectedMime = 'image/png'
|
||||
const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
|
||||
it('should map video to video/mp4', () => {
|
||||
const fileType: string = 'video'
|
||||
const expectedMime = 'video/mp4'
|
||||
const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
|
||||
it('should map audio to audio/mpeg', () => {
|
||||
const fileType: string = 'audio'
|
||||
const expectedMime = 'audio/mpeg'
|
||||
const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
|
||||
it('should map unknown to application/octet-stream', () => {
|
||||
const fileType: string = 'unknown'
|
||||
const expectedMime = 'application/octet-stream'
|
||||
const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
})
|
||||
|
||||
describe('TransferMethod selection', () => {
|
||||
it('should select remote_url for images', () => {
|
||||
const fileType: string = 'image'
|
||||
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
|
||||
expect(transferMethod).toBe('remote_url')
|
||||
})
|
||||
|
||||
it('should select local_file for non-images', () => {
|
||||
const fileType: string = 'video'
|
||||
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
|
||||
expect(transferMethod).toBe('local_file')
|
||||
})
|
||||
})
|
||||
|
||||
describe('File extension mapping', () => {
|
||||
it('should use .png extension for images', () => {
|
||||
const fileType: string = 'image'
|
||||
const expectedExtension = '.png'
|
||||
const extension = fileType === 'image' ? 'png' : 'bin'
|
||||
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||
})
|
||||
|
||||
it('should use .mp4 extension for videos', () => {
|
||||
const fileType: string = 'video'
|
||||
const expectedExtension = '.mp4'
|
||||
const extension = fileType === 'video' ? 'mp4' : 'bin'
|
||||
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||
})
|
||||
|
||||
it('should use .mp3 extension for audio', () => {
|
||||
const fileType: string = 'audio'
|
||||
const expectedExtension = '.mp3'
|
||||
const extension = fileType === 'audio' ? 'mp3' : 'bin'
|
||||
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||
})
|
||||
})
|
||||
|
||||
describe('File name generation', () => {
|
||||
it('should generate correct file name for images', () => {
|
||||
const fileType: string = 'image'
|
||||
const expectedName = 'generated_image.png'
|
||||
const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}`
|
||||
expect(fileName).toBe(expectedName)
|
||||
})
|
||||
|
||||
it('should generate correct file name for videos', () => {
|
||||
const fileType: string = 'video'
|
||||
const expectedName = 'generated_video.mp4'
|
||||
const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}`
|
||||
expect(fileName).toBe(expectedName)
|
||||
})
|
||||
|
||||
it('should generate correct file name for audio', () => {
|
||||
const fileType: string = 'audio'
|
||||
const expectedName = 'generated_audio.mp3'
|
||||
const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}`
|
||||
expect(fileName).toBe(expectedName)
|
||||
})
|
||||
})
|
||||
|
||||
describe('SupportFileType mapping', () => {
|
||||
it('should map image type to image supportFileType', () => {
|
||||
const fileType: string = 'image'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('image')
|
||||
})
|
||||
|
||||
it('should map video type to video supportFileType', () => {
|
||||
const fileType: string = 'video'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('video')
|
||||
})
|
||||
|
||||
it('should map audio type to audio supportFileType', () => {
|
||||
const fileType: string = 'audio'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('audio')
|
||||
})
|
||||
|
||||
it('should map unknown type to document supportFileType', () => {
|
||||
const fileType: string = 'unknown'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('document')
|
||||
})
|
||||
})
|
||||
|
||||
describe('File conversion logic', () => {
|
||||
it('should detect existing transferMethod', () => {
|
||||
const fileWithTransferMethod = {
|
||||
id: 'file-123',
|
||||
transferMethod: 'remote_url' as const,
|
||||
type: 'image/png',
|
||||
name: 'test.png',
|
||||
size: 1024,
|
||||
supportFileType: 'image',
|
||||
progress: 100,
|
||||
}
|
||||
const hasTransferMethod = 'transferMethod' in fileWithTransferMethod
|
||||
expect(hasTransferMethod).toBe(true)
|
||||
})
|
||||
|
||||
it('should detect missing transferMethod', () => {
|
||||
const fileWithoutTransferMethod = {
|
||||
id: 'file-456',
|
||||
type: 'image',
|
||||
url: 'http://example.com/image.png',
|
||||
belongs_to: 'assistant',
|
||||
}
|
||||
const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod
|
||||
expect(hasTransferMethod).toBe(false)
|
||||
})
|
||||
|
||||
it('should create file with size 0 for generated files', () => {
|
||||
const expectedSize = 0
|
||||
expect(expectedSize).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Agent vs Non-Agent mode logic', () => {
|
||||
it('should check for agent_thoughts to determine mode', () => {
|
||||
const agentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
|
||||
agent_thoughts: [{}],
|
||||
}
|
||||
const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0
|
||||
expect(isAgentMode).toBe(true)
|
||||
})
|
||||
|
||||
it('should detect non-agent mode when agent_thoughts is empty', () => {
|
||||
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
|
||||
agent_thoughts: [],
|
||||
}
|
||||
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
|
||||
expect(isAgentMode).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect non-agent mode when agent_thoughts is undefined', () => {
|
||||
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {}
|
||||
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
|
||||
expect(isAgentMode).toBeFalsy()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -419,9 +419,40 @@ export const useChat = (
|
||||
}
|
||||
},
|
||||
onFile(file) {
|
||||
// Convert simple file type to MIME type for non-agent mode
|
||||
// Backend sends: { id, type: "image", belongs_to, url }
|
||||
// Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size }
|
||||
|
||||
// Determine file type for MIME conversion
|
||||
const fileType = (file as { type?: string }).type || 'image'
|
||||
|
||||
// If file already has transferMethod, use it as base and ensure all required fields exist
|
||||
// Otherwise, create a new complete file object
|
||||
const baseFile = ('transferMethod' in file) ? (file as Partial<FileEntity>) : null
|
||||
|
||||
const convertedFile: FileEntity = {
|
||||
id: baseFile?.id || (file as { id: string }).id,
|
||||
type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'),
|
||||
transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'),
|
||||
uploadedId: baseFile?.uploadedId || (file as { id: string }).id,
|
||||
supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'),
|
||||
progress: baseFile?.progress ?? 100,
|
||||
name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`,
|
||||
url: baseFile?.url || (file as { url?: string }).url,
|
||||
size: baseFile?.size ?? 0, // Generated files don't have a known size
|
||||
}
|
||||
|
||||
// For agent mode, add files to the last thought
|
||||
const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1]
|
||||
if (lastThought)
|
||||
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file]
|
||||
if (lastThought) {
|
||||
const thought = lastThought as { message_files?: FileEntity[] }
|
||||
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile]
|
||||
}
|
||||
// For non-agent mode, add files directly to responseItem.message_files
|
||||
else {
|
||||
const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? []
|
||||
responseItem.message_files = [...currentFiles, convertedFile]
|
||||
}
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
|
||||
import type { ChatItem } from '../types'
|
||||
import { User } from '@/app/components/base/icons/src/public/avatar'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import Checkbox from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { Item } from '.'
|
||||
import { useState } from 'react'
|
||||
import Chip from '.'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import Confirm from '.'
|
||||
import Button from '../button'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useEffect, useState } from 'react'
|
||||
import ContentDialog from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import CopyFeedback, { CopyFeedbackNew } from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import CopyIcon from '.'
|
||||
|
||||
const meta = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import CornerLabel from '.'
|
||||
|
||||
const meta = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { DatePickerProps } from './types'
|
||||
import { useState } from 'react'
|
||||
import { fn } from 'storybook/test'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useEffect, useState } from 'react'
|
||||
import Dialog from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import Divider from '.'
|
||||
|
||||
const meta = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import { fn } from 'storybook/test'
|
||||
import DrawerPlus from '.'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import { fn } from 'storybook/test'
|
||||
import Drawer from '.'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { Item } from '.'
|
||||
import { useState } from 'react'
|
||||
import { fn } from 'storybook/test'
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/* eslint-disable tailwindcss/classnames-order */
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import Effect from '.'
|
||||
|
||||
const meta = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import EmojiPickerInner from './Inner'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import EmojiPicker from '.'
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { Features } from './types'
|
||||
import { useState } from 'react'
|
||||
import { FeaturesProvider } from '.'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import FileIcon from '.'
|
||||
|
||||
const meta = {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import FileImageRender from './file-image-render'
|
||||
|
||||
const SAMPLE_IMAGE = 'data:image/svg+xml;utf8,<svg xmlns=\'http://www.w3.org/2000/svg\' width=\'320\' height=\'180\'><defs><linearGradient id=\'grad\' x1=\'0%\' y1=\'0%\' x2=\'100%\' y2=\'100%\'><stop offset=\'0%\' stop-color=\'#FEE2FF\'/><stop offset=\'100%\' stop-color=\'#E0EAFF\'/></linearGradient></defs><rect width=\'320\' height=\'180\' rx=\'18\' fill=\'url(#grad)\'/><text x=\'50%\' y=\'50%\' dominant-baseline=\'middle\' text-anchor=\'middle\' font-family=\'sans-serif\' font-size=\'24\' fill=\'#1F2937\'>Preview</text></svg>'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { FileEntity } from './types'
|
||||
import { useState } from 'react'
|
||||
import { SupportUploadFileTypes } from '@/app/components/workflow/types'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import FileTypeIcon from './file-type-icon'
|
||||
import { FileAppearanceTypeEnum } from './types'
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user