Compare commits

..

7 Commits

Author SHA1 Message Date
yyh
41857cf9cd Fix retry handling for request bodies 2026-01-05 14:31:57 +08:00
yyh
78ab09b521 fix(nodejs-client): handle malformed JSON responses gracefully
Fix critical bugs in response body parsing:

1. Error responses (non-OK status):
   - Response body can only be read once
   - Previous code: response.json() fails → response.text() throws 'Body already read'
   - Fix: Read as text first, then parse JSON with fallback to raw text

2. Success responses (OK status):
   - Malformed JSON was not caught, causing NetworkError misclassification
   - Fix: Wrap JSON.parse in try-catch, fallback to raw text

These fixes prevent unhandled promise rejections and provide better error
messages when servers return malformed JSON responses.

Reported-by: gemini-code-assist
Severity: Critical (issue 1), High (issue 2)
2026-01-05 14:23:20 +08:00
yyh
249a491743 fix(nodejs-client): fix timeout handling in retry loop
Fix critical bug in retry loop timeout control:
- Move AbortController creation inside while loop to ensure each retry uses a fresh controller
- Previous implementation caused retries to use an already-aborted controller, making retries fail immediately
- Add eslint-disable comments to explain necessary type assertions
- Improve code comments explaining DOM ReadableStream type conversion

This fix ensures the retry mechanism works correctly with independent timeout control for each attempt.
2026-01-05 14:13:22 +08:00
yyh
df67842fae fix(nodejs-client): fix all test mocks for fetch API
- Add createMockResponse helper to properly mock fetch responses
- Update all test cases to use the mock helper with all required methods
- Fix timeout error test to use AbortError instead of fake timers
- Ensure beforeEach restores real timers to avoid test interference
- All 85 tests now passing
2026-01-05 14:04:57 +08:00
yyh
a4495ab586 chore(nodejs-client): update pnpm-lock.yaml to remove axios 2026-01-05 13:58:25 +08:00
yyh
7c3f15c507 chore(nodejs-client): bump version to 3.1.0 2026-01-05 13:57:45 +08:00
yyh
5db66ad033 refactor(nodejs-client): replace axios with native fetch API
- Replace axios with Node.js native fetch API for HTTP requests
- Update HttpClient to use fetch instead of axios instance
- Convert axios-specific error handling to fetch-based error mapping
- Update response type handling for streams, JSON, text, etc.
- Remove axios from package.json dependencies
- Update all test files to mock fetch instead of axios

This change reduces external dependencies and uses the built-in
fetch API available in Node.js 18+, which is already the minimum
required version for this SDK.
2026-01-05 13:57:22 +08:00
798 changed files with 11784 additions and 42307 deletions

View File

@ -4,19 +4,6 @@
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true,
"ralph-loop@claude-plugins-official": true
},
"hooks": {
"PreToolUse": [
{
"matcher": "Bash",
"hooks": [
{
"type": "command",
"command": "npx -y block-no-verify@1.1.1"
}
]
}
]
"ralph-wiggum@claude-plugins-official": true
}
}

View File

@ -1,355 +0,0 @@
---
name: skill-creator
description: Guide for creating effective skills. This skill should be used when users want to create a new skill (or update an existing skill) that extends Claude's capabilities with specialized knowledge, workflows, or tool integrations.
---
# Skill Creator
This skill provides guidance for creating effective skills.
## About Skills
Skills are modular, self-contained packages that extend Claude's capabilities by providing
specialized knowledge, workflows, and tools. Think of them as "onboarding guides" for specific
domains or tasks—they transform Claude from a general-purpose agent into a specialized agent
equipped with procedural knowledge that no model can fully possess.
### What Skills Provide
1. Specialized workflows - Multi-step procedures for specific domains
2. Tool integrations - Instructions for working with specific file formats or APIs
3. Domain expertise - Company-specific knowledge, schemas, business logic
4. Bundled resources - Scripts, references, and assets for complex and repetitive tasks
## Core Principles
### Concise is Key
The context window is a public good. Skills share the context window with everything else Claude needs: system prompt, conversation history, other Skills' metadata, and the actual user request.
**Default assumption: Claude is already very smart.** Only add context Claude doesn't already have. Challenge each piece of information: "Does Claude really need this explanation?" and "Does this paragraph justify its token cost?"
Prefer concise examples over verbose explanations.
### Set Appropriate Degrees of Freedom
Match the level of specificity to the task's fragility and variability:
**High freedom (text-based instructions)**: Use when multiple approaches are valid, decisions depend on context, or heuristics guide the approach.
**Medium freedom (pseudocode or scripts with parameters)**: Use when a preferred pattern exists, some variation is acceptable, or configuration affects behavior.
**Low freedom (specific scripts, few parameters)**: Use when operations are fragile and error-prone, consistency is critical, or a specific sequence must be followed.
Think of Claude as exploring a path: a narrow bridge with cliffs needs specific guardrails (low freedom), while an open field allows many routes (high freedom).
### Anatomy of a Skill
Every skill consists of a required SKILL.md file and optional bundled resources:
```
skill-name/
├── SKILL.md (required)
│ ├── YAML frontmatter metadata (required)
│ │ ├── name: (required)
│ │ └── description: (required)
│ └── Markdown instructions (required)
└── Bundled Resources (optional)
├── scripts/ - Executable code (Python/Bash/etc.)
├── references/ - Documentation intended to be loaded into context as needed
└── assets/ - Files used in output (templates, icons, fonts, etc.)
```
#### SKILL.md (required)
Every SKILL.md consists of:
- **Frontmatter** (YAML): Contains `name` and `description` fields. These are the only fields that Claude reads to determine when the skill gets used, thus it is very important to be clear and comprehensive in describing what the skill is, and when it should be used.
- **Body** (Markdown): Instructions and guidance for using the skill. Only loaded AFTER the skill triggers (if at all).
#### Bundled Resources (optional)
##### Scripts (`scripts/`)
Executable code (Python/Bash/etc.) for tasks that require deterministic reliability or are repeatedly rewritten.
- **When to include**: When the same code is being rewritten repeatedly or deterministic reliability is needed
- **Example**: `scripts/rotate_pdf.py` for PDF rotation tasks
- **Benefits**: Token efficient, deterministic, may be executed without loading into context
- **Note**: Scripts may still need to be read by Claude for patching or environment-specific adjustments
##### References (`references/`)
Documentation and reference material intended to be loaded as needed into context to inform Claude's process and thinking.
- **When to include**: For documentation that Claude should reference while working
- **Examples**: `references/finance.md` for financial schemas, `references/mnda.md` for company NDA template, `references/policies.md` for company policies, `references/api_docs.md` for API specifications
- **Use cases**: Database schemas, API documentation, domain knowledge, company policies, detailed workflow guides
- **Benefits**: Keeps SKILL.md lean, loaded only when Claude determines it's needed
- **Best practice**: If files are large (>10k words), include grep search patterns in SKILL.md
- **Avoid duplication**: Information should live in either SKILL.md or references files, not both. Prefer references files for detailed information unless it's truly core to the skill—this keeps SKILL.md lean while making information discoverable without hogging the context window. Keep only essential procedural instructions and workflow guidance in SKILL.md; move detailed reference material, schemas, and examples to references files.
##### Assets (`assets/`)
Files not intended to be loaded into context, but rather used within the output Claude produces.
- **When to include**: When the skill needs files that will be used in the final output
- **Examples**: `assets/logo.png` for brand assets, `assets/slides.pptx` for PowerPoint templates, `assets/frontend-template/` for HTML/React boilerplate, `assets/font.ttf` for typography
- **Use cases**: Templates, images, icons, boilerplate code, fonts, sample documents that get copied or modified
- **Benefits**: Separates output resources from documentation, enables Claude to use files without loading them into context
#### What to Not Include in a Skill
A skill should only contain essential files that directly support its functionality. Do NOT create extraneous documentation or auxiliary files, including:
- README.md
- INSTALLATION_GUIDE.md
- QUICK_REFERENCE.md
- CHANGELOG.md
- etc.
The skill should only contain the information needed for an AI agent to do the job at hand. It should not contain auxilary context about the process that went into creating it, setup and testing procedures, user-facing documentation, etc. Creating additional documentation files just adds clutter and confusion.
### Progressive Disclosure Design Principle
Skills use a three-level loading system to manage context efficiently:
1. **Metadata (name + description)** - Always in context (~100 words)
2. **SKILL.md body** - When skill triggers (<5k words)
3. **Bundled resources** - As needed by Claude (Unlimited because scripts can be executed without reading into context window)
#### Progressive Disclosure Patterns
Keep SKILL.md body to the essentials and under 500 lines to minimize context bloat. Split content into separate files when approaching this limit. When splitting out content into other files, it is very important to reference them from SKILL.md and describe clearly when to read them, to ensure the reader of the skill knows they exist and when to use them.
**Key principle:** When a skill supports multiple variations, frameworks, or options, keep only the core workflow and selection guidance in SKILL.md. Move variant-specific details (patterns, examples, configuration) into separate reference files.
**Pattern 1: High-level guide with references**
```markdown
# PDF Processing
## Quick start
Extract text with pdfplumber:
[code example]
## Advanced features
- **Form filling**: See [FORMS.md](FORMS.md) for complete guide
- **API reference**: See [REFERENCE.md](REFERENCE.md) for all methods
- **Examples**: See [EXAMPLES.md](EXAMPLES.md) for common patterns
```
Claude loads FORMS.md, REFERENCE.md, or EXAMPLES.md only when needed.
**Pattern 2: Domain-specific organization**
For Skills with multiple domains, organize content by domain to avoid loading irrelevant context:
```
bigquery-skill/
├── SKILL.md (overview and navigation)
└── reference/
├── finance.md (revenue, billing metrics)
├── sales.md (opportunities, pipeline)
├── product.md (API usage, features)
└── marketing.md (campaigns, attribution)
```
When a user asks about sales metrics, Claude only reads sales.md.
Similarly, for skills supporting multiple frameworks or variants, organize by variant:
```
cloud-deploy/
├── SKILL.md (workflow + provider selection)
└── references/
├── aws.md (AWS deployment patterns)
├── gcp.md (GCP deployment patterns)
└── azure.md (Azure deployment patterns)
```
When the user chooses AWS, Claude only reads aws.md.
**Pattern 3: Conditional details**
Show basic content, link to advanced content:
```markdown
# DOCX Processing
## Creating documents
Use docx-js for new documents. See [DOCX-JS.md](DOCX-JS.md).
## Editing documents
For simple edits, modify the XML directly.
**For tracked changes**: See [REDLINING.md](REDLINING.md)
**For OOXML details**: See [OOXML.md](OOXML.md)
```
Claude reads REDLINING.md or OOXML.md only when the user needs those features.
**Important guidelines:**
- **Avoid deeply nested references** - Keep references one level deep from SKILL.md. All reference files should link directly from SKILL.md.
- **Structure longer reference files** - For files longer than 100 lines, include a table of contents at the top so Claude can see the full scope when previewing.
## Skill Creation Process
Skill creation involves these steps:
1. Understand the skill with concrete examples
2. Plan reusable skill contents (scripts, references, assets)
3. Initialize the skill (run init_skill.py)
4. Edit the skill (implement resources and write SKILL.md)
5. Package the skill (run package_skill.py)
6. Iterate based on real usage
Follow these steps in order, skipping only if there is a clear reason why they are not applicable.
### Step 1: Understanding the Skill with Concrete Examples
Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill.
To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback.
For example, when building an image-editor skill, relevant questions include:
- "What functionality should the image-editor skill support? Editing, rotating, anything else?"
- "Can you give some examples of how this skill would be used?"
- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?"
- "What would a user say that should trigger this skill?"
To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed for better effectiveness.
Conclude this step when there is a clear sense of the functionality the skill should support.
### Step 2: Planning the Reusable Skill Contents
To turn concrete examples into an effective skill, analyze each example by:
1. Considering how to execute on the example from scratch
2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly
Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows:
1. Rotating a PDF requires re-writing the same code each time
2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill
Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows:
1. Writing a frontend webapp requires the same boilerplate HTML/React each time
2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill
Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows:
1. Querying BigQuery requires re-discovering the table schemas and relationships each time
2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill
To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets.
### Step 3: Initializing the Skill
At this point, it is time to actually create the skill.
Skip this step only if the skill being developed already exists, and iteration or packaging is needed. In this case, continue to the next step.
When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.
Usage:
```bash
scripts/init_skill.py <skill-name> --path <output-directory>
```
The script:
- Creates the skill directory at the specified path
- Generates a SKILL.md template with proper frontmatter and TODO placeholders
- Creates example resource directories: `scripts/`, `references/`, and `assets/`
- Adds example files in each directory that can be customized or deleted
After initialization, customize or remove the generated SKILL.md and example files as needed.
### Step 4: Edit the Skill
When editing the (newly-generated or existing) skill, remember that the skill is being created for another instance of Claude to use. Include information that would be beneficial and non-obvious to Claude. Consider what procedural knowledge, domain-specific details, or reusable assets would help another Claude instance execute these tasks more effectively.
#### Learn Proven Design Patterns
Consult these helpful guides based on your skill's needs:
- **Multi-step processes**: See references/workflows.md for sequential workflows and conditional logic
- **Specific output formats or quality standards**: See references/output-patterns.md for template and example patterns
These files contain established best practices for effective skill design.
#### Start with Reusable Skill Contents
To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`.
Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested to ensure confidence that they all work while balancing time to completion.
Any example files and directories not needed for the skill should be deleted. The initialization script creates example files in `scripts/`, `references/`, and `assets/` to demonstrate structure, but most skills won't need all of them.
#### Update SKILL.md
**Writing Guidelines:** Always use imperative/infinitive form.
##### Frontmatter
Write the YAML frontmatter with `name` and `description`:
- `name`: The skill name
- `description`: This is the primary triggering mechanism for your skill, and helps Claude understand when to use the skill.
- Include both what the Skill does and specific triggers/contexts for when to use it.
- Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to Claude.
- Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when Claude needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"
Do not include any other fields in YAML frontmatter.
##### Body
Write instructions for using the skill and its bundled resources.
### Step 5: Packaging a Skill
Once development of the skill is complete, it must be packaged into a distributable .skill file that gets shared with the user. The packaging process automatically validates the skill first to ensure it meets all requirements:
```bash
scripts/package_skill.py <path/to/skill-folder>
```
Optional output directory specification:
```bash
scripts/package_skill.py <path/to/skill-folder> ./dist
```
The packaging script will:
1. **Validate** the skill automatically, checking:
- YAML frontmatter format and required fields
- Skill naming conventions and directory structure
- Description completeness and quality
- File organization and resource references
2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension.
If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again.
### Step 6: Iterate
After testing the skill, users may request improvements. Often this happens right after using the skill, with fresh context of how the skill performed.
**Iteration workflow:**
1. Use the skill on real tasks
2. Notice struggles or inefficiencies
3. Identify how SKILL.md or bundled resources should be updated
4. Implement changes and test again

View File

@ -1,86 +0,0 @@
# Output Patterns
Use these patterns when skills need to produce consistent, high-quality output.
## Template Pattern
Provide templates for output format. Match the level of strictness to your needs.
**For strict requirements (like API responses or data formats):**
```markdown
## Report structure
ALWAYS use this exact template structure:
# [Analysis Title]
## Executive summary
[One-paragraph overview of key findings]
## Key findings
- Finding 1 with supporting data
- Finding 2 with supporting data
- Finding 3 with supporting data
## Recommendations
1. Specific actionable recommendation
2. Specific actionable recommendation
```
**For flexible guidance (when adaptation is useful):**
```markdown
## Report structure
Here is a sensible default format, but use your best judgment:
# [Analysis Title]
## Executive summary
[Overview]
## Key findings
[Adapt sections based on what you discover]
## Recommendations
[Tailor to the specific context]
Adjust sections as needed for the specific analysis type.
```
## Examples Pattern
For skills where output quality depends on seeing examples, provide input/output pairs:
```markdown
## Commit message format
Generate commit messages following these examples:
**Example 1:**
Input: Added user authentication with JWT tokens
Output:
```
feat(auth): implement JWT-based authentication
Add login endpoint and token validation middleware
```
**Example 2:**
Input: Fixed bug where dates displayed incorrectly in reports
Output:
```
fix(reports): correct date formatting in timezone conversion
Use UTC timestamps consistently across report generation
```
Follow this style: type(scope): brief description, then detailed explanation.
```
Examples help Claude understand the desired style and level of detail more clearly than descriptions alone.

View File

@ -1,28 +0,0 @@
# Workflow Patterns
## Sequential Workflows
For complex tasks, break operations into clear, sequential steps. It is often helpful to give Claude an overview of the process towards the beginning of SKILL.md:
```markdown
Filling a PDF form involves these steps:
1. Analyze the form (run analyze_form.py)
2. Create field mapping (edit fields.json)
3. Validate mapping (run validate_fields.py)
4. Fill the form (run fill_form.py)
5. Verify output (run verify_output.py)
```
## Conditional Workflows
For tasks with branching logic, guide Claude through decision points:
```markdown
1. Determine the modification type:
**Creating new content?** → Follow "Creation workflow" below
**Editing existing content?** → Follow "Editing workflow" below
2. Creation workflow: [steps]
3. Editing workflow: [steps]
```

View File

@ -1,300 +0,0 @@
#!/usr/bin/env python3
"""
Skill Initializer - Creates a new skill from template
Usage:
init_skill.py <skill-name> --path <path>
Examples:
init_skill.py my-new-skill --path skills/public
init_skill.py my-api-helper --path skills/private
init_skill.py custom-skill --path /custom/location
"""
import sys
from pathlib import Path
SKILL_TEMPLATE = """---
name: {skill_name}
description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.]
---
# {skill_title}
## Overview
[TODO: 1-2 sentences explaining what this skill enables]
## Structuring This Skill
[TODO: Choose the structure that best fits this skill's purpose. Common patterns:
**1. Workflow-Based** (best for sequential processes)
- Works well when there are clear step-by-step procedures
- Example: DOCX skill with "Workflow Decision Tree""Reading""Creating""Editing"
- Structure: ## Overview → ## Workflow Decision Tree → ## Step 1 → ## Step 2...
**2. Task-Based** (best for tool collections)
- Works well when the skill offers different operations/capabilities
- Example: PDF skill with "Quick Start""Merge PDFs""Split PDFs""Extract Text"
- Structure: ## Overview → ## Quick Start → ## Task Category 1 → ## Task Category 2...
**3. Reference/Guidelines** (best for standards or specifications)
- Works well for brand guidelines, coding standards, or requirements
- Example: Brand styling with "Brand Guidelines""Colors""Typography""Features"
- Structure: ## Overview → ## Guidelines → ## Specifications → ## Usage...
**4. Capabilities-Based** (best for integrated systems)
- Works well when the skill provides multiple interrelated features
- Example: Product Management with "Core Capabilities" → numbered capability list
- Structure: ## Overview → ## Core Capabilities → ### 1. Feature → ### 2. Feature...
Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations).
Delete this entire "Structuring This Skill" section when done - it's just guidance.]
## [TODO: Replace with the first main section based on chosen structure]
[TODO: Add content here. See examples in existing skills:
- Code samples for technical skills
- Decision trees for complex workflows
- Concrete examples with realistic user requests
- References to scripts/templates/references as needed]
## Resources
This skill includes example resource directories that demonstrate how to organize different types of bundled resources:
### scripts/
Executable code (Python/Bash/etc.) that can be run directly to perform specific operations.
**Examples from other skills:**
- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation
- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing
**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations.
**Note:** Scripts may be executed without loading into context, but can still be read by Claude for patching or environment adjustments.
### references/
Documentation and reference material intended to be loaded into context to inform Claude's process and thinking.
**Examples from other skills:**
- Product management: `communication.md`, `context_building.md` - detailed workflow guides
- BigQuery: API reference documentation and query examples
- Finance: Schema documentation, company policies
**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Claude should reference while working.
### assets/
Files not intended to be loaded into context, but rather used within the output Claude produces.
**Examples from other skills:**
- Brand styling: PowerPoint template files (.pptx), logo files
- Frontend builder: HTML/React boilerplate project directories
- Typography: Font files (.ttf, .woff2)
**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output.
---
**Any unneeded directories can be deleted.** Not every skill requires all three types of resources.
"""
EXAMPLE_SCRIPT = '''#!/usr/bin/env python3
"""
Example helper script for {skill_name}
This is a placeholder script that can be executed directly.
Replace with actual implementation or delete if not needed.
Example real scripts from other skills:
- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields
- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images
"""
def main():
print("This is an example script for {skill_name}")
# TODO: Add actual script logic here
# This could be data processing, file conversion, API calls, etc.
if __name__ == "__main__":
main()
'''
EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title}
This is a placeholder for detailed reference documentation.
Replace with actual reference content or delete if not needed.
Example real reference docs from other skills:
- product-management/references/communication.md - Comprehensive guide for status updates
- product-management/references/context_building.md - Deep-dive on gathering context
- bigquery/references/ - API references and query examples
## When Reference Docs Are Useful
Reference docs are ideal for:
- Comprehensive API documentation
- Detailed workflow guides
- Complex multi-step processes
- Information too lengthy for main SKILL.md
- Content that's only needed for specific use cases
## Structure Suggestions
### API Reference Example
- Overview
- Authentication
- Endpoints with examples
- Error codes
- Rate limits
### Workflow Guide Example
- Prerequisites
- Step-by-step instructions
- Common patterns
- Troubleshooting
- Best practices
"""
EXAMPLE_ASSET = """# Example Asset File
This placeholder represents where asset files would be stored.
Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed.
Asset files are NOT intended to be loaded into context, but rather used within
the output Claude produces.
Example asset files from other skills:
- Brand guidelines: logo.png, slides_template.pptx
- Frontend builder: hello-world/ directory with HTML/React boilerplate
- Typography: custom-font.ttf, font-family.woff2
- Data: sample_data.csv, test_dataset.json
## Common Asset Types
- Templates: .pptx, .docx, boilerplate directories
- Images: .png, .jpg, .svg, .gif
- Fonts: .ttf, .otf, .woff, .woff2
- Boilerplate code: Project directories, starter files
- Icons: .ico, .svg
- Data files: .csv, .json, .xml, .yaml
Note: This is a text placeholder. Actual assets can be any file type.
"""
def title_case_skill_name(skill_name):
"""Convert hyphenated skill name to Title Case for display."""
return " ".join(word.capitalize() for word in skill_name.split("-"))
def init_skill(skill_name, path):
"""
Initialize a new skill directory with template SKILL.md.
Args:
skill_name: Name of the skill
path: Path where the skill directory should be created
Returns:
Path to created skill directory, or None if error
"""
# Determine skill directory path
skill_dir = Path(path).resolve() / skill_name
# Check if directory already exists
if skill_dir.exists():
print(f"❌ Error: Skill directory already exists: {skill_dir}")
return None
# Create skill directory
try:
skill_dir.mkdir(parents=True, exist_ok=False)
print(f"✅ Created skill directory: {skill_dir}")
except Exception as e:
print(f"❌ Error creating directory: {e}")
return None
# Create SKILL.md from template
skill_title = title_case_skill_name(skill_name)
skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title)
skill_md_path = skill_dir / "SKILL.md"
try:
skill_md_path.write_text(skill_content)
print("✅ Created SKILL.md")
except Exception as e:
print(f"❌ Error creating SKILL.md: {e}")
return None
# Create resource directories with example files
try:
# Create scripts/ directory with example script
scripts_dir = skill_dir / "scripts"
scripts_dir.mkdir(exist_ok=True)
example_script = scripts_dir / "example.py"
example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name))
example_script.chmod(0o755)
print("✅ Created scripts/example.py")
# Create references/ directory with example reference doc
references_dir = skill_dir / "references"
references_dir.mkdir(exist_ok=True)
example_reference = references_dir / "api_reference.md"
example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title))
print("✅ Created references/api_reference.md")
# Create assets/ directory with example asset placeholder
assets_dir = skill_dir / "assets"
assets_dir.mkdir(exist_ok=True)
example_asset = assets_dir / "example_asset.txt"
example_asset.write_text(EXAMPLE_ASSET)
print("✅ Created assets/example_asset.txt")
except Exception as e:
print(f"❌ Error creating resource directories: {e}")
return None
# Print next steps
print(f"\n✅ Skill '{skill_name}' initialized successfully at {skill_dir}")
print("\nNext steps:")
print("1. Edit SKILL.md to complete the TODO items and update the description")
print("2. Customize or delete the example files in scripts/, references/, and assets/")
print("3. Run the validator when ready to check the skill structure")
return skill_dir
def main():
if len(sys.argv) < 4 or sys.argv[2] != "--path":
print("Usage: init_skill.py <skill-name> --path <path>")
print("\nSkill name requirements:")
print(" - Hyphen-case identifier (e.g., 'data-analyzer')")
print(" - Lowercase letters, digits, and hyphens only")
print(" - Max 40 characters")
print(" - Must match directory name exactly")
print("\nExamples:")
print(" init_skill.py my-new-skill --path skills/public")
print(" init_skill.py my-api-helper --path skills/private")
print(" init_skill.py custom-skill --path /custom/location")
sys.exit(1)
skill_name = sys.argv[1]
path = sys.argv[3]
print(f"🚀 Initializing skill: {skill_name}")
print(f" Location: {path}")
print()
result = init_skill(skill_name, path)
if result:
sys.exit(0)
else:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -1,110 +0,0 @@
#!/usr/bin/env python3
"""
Skill Packager - Creates a distributable .skill file of a skill folder
Usage:
python utils/package_skill.py <path/to/skill-folder> [output-directory]
Example:
python utils/package_skill.py skills/public/my-skill
python utils/package_skill.py skills/public/my-skill ./dist
"""
import sys
import zipfile
from pathlib import Path
from quick_validate import validate_skill
def package_skill(skill_path, output_dir=None):
"""
Package a skill folder into a .skill file.
Args:
skill_path: Path to the skill folder
output_dir: Optional output directory for the .skill file (defaults to current directory)
Returns:
Path to the created .skill file, or None if error
"""
skill_path = Path(skill_path).resolve()
# Validate skill folder exists
if not skill_path.exists():
print(f"❌ Error: Skill folder not found: {skill_path}")
return None
if not skill_path.is_dir():
print(f"❌ Error: Path is not a directory: {skill_path}")
return None
# Validate SKILL.md exists
skill_md = skill_path / "SKILL.md"
if not skill_md.exists():
print(f"❌ Error: SKILL.md not found in {skill_path}")
return None
# Run validation before packaging
print("🔍 Validating skill...")
valid, message = validate_skill(skill_path)
if not valid:
print(f"❌ Validation failed: {message}")
print(" Please fix the validation errors before packaging.")
return None
print(f"{message}\n")
# Determine output location
skill_name = skill_path.name
if output_dir:
output_path = Path(output_dir).resolve()
output_path.mkdir(parents=True, exist_ok=True)
else:
output_path = Path.cwd()
skill_filename = output_path / f"{skill_name}.skill"
# Create the .skill file (zip format)
try:
with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
# Walk through the skill directory
for file_path in skill_path.rglob("*"):
if file_path.is_file():
# Calculate the relative path within the zip
arcname = file_path.relative_to(skill_path.parent)
zipf.write(file_path, arcname)
print(f" Added: {arcname}")
print(f"\n✅ Successfully packaged skill to: {skill_filename}")
return skill_filename
except Exception as e:
print(f"❌ Error creating .skill file: {e}")
return None
def main():
if len(sys.argv) < 2:
print("Usage: python utils/package_skill.py <path/to/skill-folder> [output-directory]")
print("\nExample:")
print(" python utils/package_skill.py skills/public/my-skill")
print(" python utils/package_skill.py skills/public/my-skill ./dist")
sys.exit(1)
skill_path = sys.argv[1]
output_dir = sys.argv[2] if len(sys.argv) > 2 else None
print(f"📦 Packaging skill: {skill_path}")
if output_dir:
print(f" Output directory: {output_dir}")
print()
result = package_skill(skill_path, output_dir)
if result:
sys.exit(0)
else:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@ -1,97 +0,0 @@
#!/usr/bin/env python3
"""
Quick validation script for skills - minimal version
"""
import sys
import os
import re
import yaml
from pathlib import Path
def validate_skill(skill_path):
"""Basic validation of a skill"""
skill_path = Path(skill_path)
# Check SKILL.md exists
skill_md = skill_path / "SKILL.md"
if not skill_md.exists():
return False, "SKILL.md not found"
# Read and validate frontmatter
content = skill_md.read_text()
if not content.startswith("---"):
return False, "No YAML frontmatter found"
# Extract frontmatter
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
if not match:
return False, "Invalid frontmatter format"
frontmatter_text = match.group(1)
# Parse YAML frontmatter
try:
frontmatter = yaml.safe_load(frontmatter_text)
if not isinstance(frontmatter, dict):
return False, "Frontmatter must be a YAML dictionary"
except yaml.YAMLError as e:
return False, f"Invalid YAML in frontmatter: {e}"
# Define allowed properties
ALLOWED_PROPERTIES = {"name", "description", "license", "allowed-tools", "metadata"}
# Check for unexpected properties (excluding nested keys under metadata)
unexpected_keys = set(frontmatter.keys()) - ALLOWED_PROPERTIES
if unexpected_keys:
return False, (
f"Unexpected key(s) in SKILL.md frontmatter: {', '.join(sorted(unexpected_keys))}. "
f"Allowed properties are: {', '.join(sorted(ALLOWED_PROPERTIES))}"
)
# Check required fields
if "name" not in frontmatter:
return False, "Missing 'name' in frontmatter"
if "description" not in frontmatter:
return False, "Missing 'description' in frontmatter"
# Extract name for validation
name = frontmatter.get("name", "")
if not isinstance(name, str):
return False, f"Name must be a string, got {type(name).__name__}"
name = name.strip()
if name:
# Check naming convention (hyphen-case: lowercase with hyphens)
if not re.match(r"^[a-z0-9-]+$", name):
return False, f"Name '{name}' should be hyphen-case (lowercase letters, digits, and hyphens only)"
if name.startswith("-") or name.endswith("-") or "--" in name:
return False, f"Name '{name}' cannot start/end with hyphen or contain consecutive hyphens"
# Check name length (max 64 characters per spec)
if len(name) > 64:
return False, f"Name is too long ({len(name)} characters). Maximum is 64 characters."
# Extract and validate description
description = frontmatter.get("description", "")
if not isinstance(description, str):
return False, f"Description must be a string, got {type(description).__name__}"
description = description.strip()
if description:
# Check for angle brackets
if "<" in description or ">" in description:
return False, "Description cannot contain angle brackets (< or >)"
# Check description length (max 1024 characters per spec)
if len(description) > 1024:
return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters."
return True, "Skill is valid!"
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python quick_validate.py <skill_directory>")
sys.exit(1)
valid, message = validate_skill(sys.argv[1])
print(message)
sys.exit(0 if valid else 1)

View File

@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods

View File

@ -39,6 +39,12 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Run pyrefly check
run: |
cd api
uv add --dev pyrefly
uv run pyrefly check || true
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py

View File

@ -1,29 +0,0 @@
name: Deploy HITL
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "feat/hitl-frontend"
- "feat/hitl-backend"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
(
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
)
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.HITL_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}

View File

@ -1,4 +1,4 @@
name: Deploy Agent Dev
name: Deploy Trigger Dev
permissions:
contents: read
@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/agent-dev"
- "deploy/trigger-dev"
types:
- completed
@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@ -0,0 +1,94 @@
name: Translate i18n Files Based on English
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:
permissions:
contents: write
pull-requests: write
jobs:
check-and-update:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
defaults:
run:
working-directory: web
steps:
# Keep use old checkout action version for https://github.com/peter-evans/create-pull-request/issues/4272
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Check for file changes in i18n/en-US
id: check_files
run: |
# Skip check for manual trigger, translate all files
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
echo "FILE_ARGS=" >> $GITHUB_ENV
echo "Manual trigger: translating all files"
else
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
fi
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v6
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Generate i18n translations
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run i18n:gen ${{ env.FILE_ARGS }}
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore(i18n): update translations based on en-US changes'
title: 'chore(i18n): translate i18n files based on en-US changes'
body: |
This PR was automatically created to update i18n translation files based on changes in en-US locale.
**Triggered by:** ${{ github.sha }}
**Changes included:**
- Updated translation files for all locales
branch: chore/automated-i18n-updates-${{ github.sha }}
delete-branch: true

View File

@ -1,421 +0,0 @@
name: Translate i18n Files with Claude Code
# Note: claude-code-action doesn't support push events directly.
# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
# See: https://github.com/langgenius/dify/issues/30743
on:
repository_dispatch:
types: [i18n-sync]
workflow_dispatch:
inputs:
files:
description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.'
required: false
type: string
languages:
description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.'
required: false
type: string
mode:
description: 'Sync mode: incremental (only changes) or full (re-check all keys)'
required: false
default: 'incremental'
type: choice
options:
- incremental
- full
permissions:
contents: write
pull-requests: write
jobs:
translate:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 60
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Configure Git
run: |
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Detect changed files and generate diff
id: detect_changes
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
# Manual trigger
if [ -n "${{ github.event.inputs.files }}" ]; then
echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT
else
# Get all JSON files in en-US directory
files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ')
echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT
fi
echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT
echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT
# For manual trigger with incremental mode, get diff from last commit
# For full mode, we'll do a complete check anyway
if [ "${{ github.event.inputs.mode }}" == "full" ]; then
echo "Full mode: will check all keys" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
else
git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
if [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
fi
elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
# Triggered by push via trigger-i18n-sync.yml workflow
# Validate required payload fields
if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
exit 1
fi
echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
# Decode the base64-encoded diff from the trigger workflow
if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
echo "Warning: Failed to decode base64 diff payload" >&2
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
elif [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
else
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
else
echo "Unsupported event type: ${{ github.event_name }}"
exit 1
fi
# Truncate diff if too large (keep first 50KB)
if [ -f /tmp/i18n-diff.txt ]; then
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
fi
echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')"
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
prompt: |
You are a professional i18n synchronization engineer for the Dify project.
Your task is to keep all language translations in sync with the English source (en-US).
## CRITICAL TOOL RESTRICTIONS
- Use **Read** tool to read files (NOT cat or bash)
- Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts)
- Use **Bash** ONLY for: git commands, gh commands, pnpm commands
- Run bash commands ONE BY ONE, never combine with && or ||
- NEVER use `$()` command substitution - it's not supported. Split into separate commands instead.
## WORKING DIRECTORY & ABSOLUTE PATHS
Claude Code sandbox working directory may vary. Always use absolute paths:
- For pnpm: `pnpm --dir ${{ github.workspace }}/web <command>`
- For git: `git -C ${{ github.workspace }} <command>`
- For gh: `gh --repo ${{ github.repository }} <command>`
- For file paths: `${{ github.workspace }}/web/i18n/`
## EFFICIENCY RULES
- **ONE Edit per language file** - batch all key additions into a single Edit
- Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them
- Translate ALL keys for a language mentally first, then do ONE Edit
## Context
- Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
- Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }}
- Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
- Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json
- Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts
- Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }}
## CRITICAL DESIGN: Verify First, Then Sync
You MUST follow this three-phase approach:
═══════════════════════════════════════════════════════════════
║ PHASE 1: VERIFY - Analyze and Generate Change Report ║
═══════════════════════════════════════════════════════════════
### Step 1.1: Analyze Git Diff (for incremental mode)
Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff.
Parse the diff to categorize changes:
- Lines with `+` (not `+++`): Added or modified values
- Lines with `-` (not `---`): Removed or old values
- Identify specific keys for each category:
* ADD: Keys that appear only in `+` lines (new keys)
* UPDATE: Keys that appear in both `-` and `+` lines (value changed)
* DELETE: Keys that appear only in `-` lines (removed keys)
### Step 1.2: Read Language Configuration
Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`.
Extract all languages with `supported: true`.
### Step 1.3: Run i18n:check for Each Language
```bash
pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile
```
```bash
pnpm --dir ${{ github.workspace }}/web run i18n:check
```
This will report:
- Missing keys (need to ADD)
- Extra keys (need to DELETE)
### Step 1.4: Generate Change Report
Create a structured report identifying:
```
╔══════════════════════════════════════════════════════════════╗
║ I18N SYNC CHANGE REPORT ║
╠══════════════════════════════════════════════════════════════╣
║ Files to process: [list] ║
║ Languages to sync: [list] ║
╠══════════════════════════════════════════════════════════════╣
║ ADD (New Keys): ║
║ - [filename].[key]: "English value" ║
║ ... ║
╠══════════════════════════════════════════════════════════════╣
║ UPDATE (Modified Keys - MUST re-translate): ║
║ - [filename].[key]: "Old value" → "New value" ║
║ ... ║
╠══════════════════════════════════════════════════════════════╣
║ DELETE (Extra Keys): ║
║ - [language]/[filename].[key] ║
║ ... ║
╚══════════════════════════════════════════════════════════════╝
```
**IMPORTANT**: For UPDATE detection, compare git diff to find keys where
the English value changed. These MUST be re-translated even if target
language already has a translation (it's now stale!).
═══════════════════════════════════════════════════════════════
║ PHASE 2: SYNC - Execute Changes Based on Report ║
═══════════════════════════════════════════════════════════════
### Step 2.1: Process ADD Operations (BATCH per language file)
**CRITICAL WORKFLOW for efficiency:**
1. First, translate ALL new keys for ALL languages mentally
2. Then, for EACH language file, do ONE Edit operation:
- Read the file once
- Insert ALL new keys at the beginning (right after the opening `{`)
- Don't worry about alphabetical order - lint:fix will sort them later
Example Edit (adding 3 keys to zh-Hans/app.json):
```
old_string: '{\n "accessControl"'
new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"'
```
**IMPORTANT**:
- ONE Edit per language file (not one Edit per key!)
- Always use the Edit tool. NEVER use bash scripts, node, or jq.
### Step 2.2: Process UPDATE Operations
**IMPORTANT: Special handling for zh-Hans and ja-JP**
If zh-Hans or ja-JP files were ALSO modified in the same push:
- Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files
- If found, it means someone manually translated them. Apply these rules:
1. **Missing keys**: Still ADD them (completeness required)
2. **Existing translations**: Compare with the NEW English value:
- If translation is **completely wrong** or **unrelated** → Update it
- If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work
- When in doubt, **keep the manual translation**
Example:
- English changed: "Save" → "Save Changes"
- Manual translation: "保存更改" → Keep it (correct meaning)
- Manual translation: "删除" → Update it (completely wrong)
For other languages:
Use Edit tool to replace the old value with the new translation.
You can batch multiple updates in one Edit if they are adjacent.
### Step 2.3: Process DELETE Operations
For extra keys reported by i18n:check:
- Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove`
- Or manually remove from target language JSON files
## Translation Guidelines
- PRESERVE all placeholders exactly as-is:
- `{{variable}}` - Mustache interpolation
- `${variable}` - Template literal
- `<tag>content</tag>` - HTML tags
- `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values)
- Use appropriate language register (formal/informal) based on existing translations
- Match existing translation style in each language
- Technical terms: check existing conventions per language
- For CJK languages: no spaces between characters unless necessary
- For RTL languages (ar-TN, fa-IR): ensure proper text handling
## Output Format Requirements
- Alphabetical key ordering (if original file uses it)
- 2-space indentation
- Trailing newline at end of file
- Valid JSON (use proper escaping for special characters)
═══════════════════════════════════════════════════════════════
║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║
═══════════════════════════════════════════════════════════════
### Step 3.1: Run Lint Fix (IMPORTANT!)
```bash
pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json'
```
This ensures:
- JSON keys are sorted alphabetically (jsonc/sort-keys rule)
- Valid i18n keys (dify-i18n/valid-i18n-keys rule)
- No extra keys (dify-i18n/no-extra-keys rule)
### Step 3.2: Run Final i18n Check
```bash
pnpm --dir ${{ github.workspace }}/web run i18n:check
```
### Step 3.3: Fix Any Remaining Issues
If check reports issues:
- Go back to PHASE 2 for unresolved items
- Repeat until check passes
### Step 3.4: Generate Final Summary
```
╔══════════════════════════════════════════════════════════════╗
║ SYNC COMPLETED SUMMARY ║
╠══════════════════════════════════════════════════════════════╣
║ Language │ Added │ Updated │ Deleted │ Status ║
╠══════════════════════════════════════════════════════════════╣
║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║
║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║
║ ... │ ... │ ... │ ... │ ... ║
╠══════════════════════════════════════════════════════════════╣
║ i18n:check │ PASSED - All keys in sync ║
╚══════════════════════════════════════════════════════════════╝
```
## Mode-Specific Behavior
**SYNC_MODE = "incremental"** (default):
- Focus on keys identified from git diff
- Also check i18n:check output for any missing/extra keys
- Efficient for small changes
**SYNC_MODE = "full"**:
- Compare ALL keys between en-US and each language
- Run i18n:check to identify all discrepancies
- Use for first-time sync or fixing historical issues
## Important Notes
1. Always run i18n:check BEFORE and AFTER making changes
2. The check script is the source of truth for missing/extra keys
3. For UPDATE scenario: git diff is the source of truth for changed values
4. Create a single commit with all translation changes
5. If any translation fails, continue with others and report failures
═══════════════════════════════════════════════════════════════
║ PHASE 4: COMMIT AND CREATE PR ║
═══════════════════════════════════════════════════════════════
After all translations are complete and verified:
### Step 4.1: Check for changes
```bash
git -C ${{ github.workspace }} status --porcelain
```
If there are changes:
### Step 4.2: Create a new branch and commit
Run these git commands ONE BY ONE (not combined with &&).
**IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands:
1. First, get the timestamp:
```bash
date +%Y%m%d-%H%M%S
```
(Note the output, e.g., "20260115-143052")
2. Then create branch using the timestamp value:
```bash
git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052
```
(Replace "20260115-143052" with the actual timestamp from step 1)
3. Stage changes:
```bash
git -C ${{ github.workspace }} add web/i18n/
```
4. Commit:
```bash
git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}"
```
5. Push:
```bash
git -C ${{ github.workspace }} push origin HEAD
```
### Step 4.3: Create Pull Request
```bash
gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary
This PR was automatically generated to sync i18n translation files.
### Changes
- Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
- Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
### Verification
- [x] \`i18n:check\` passed
- [x] \`lint:fix\` applied
🤖 Generated with Claude Code GitHub Action" --base main
```
claude_args: |
--max-turns 150
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"

View File

@ -1,66 +0,0 @@
name: Trigger i18n Sync on Push
# This workflow bridges the push event to repository_dispatch
# because claude-code-action doesn't support push events directly.
# See: https://github.com/langgenius/dify/issues/30743
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
permissions:
contents: write
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Detect changed files and generate diff
id: detect
run: |
BEFORE_SHA="${{ github.event.before }}"
# Handle edge case: force push may have null/zero SHA
if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
BEFORE_SHA="HEAD~1"
fi
# Detect changed i18n files
changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
echo "changed_files=$changed" >> $GITHUB_OUTPUT
# Generate diff for context
git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
# Truncate if too large (keep first 50KB to match receiving workflow)
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
# Base64 encode the diff for safe JSON transport (portable, single-line)
diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')
echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT
if [ -n "$changed" ]; then
echo "has_changes=true" >> $GITHUB_OUTPUT
echo "Detected changed files: $changed"
else
echo "has_changes=false" >> $GITHUB_OUTPUT
echo "No i18n changes detected"
fi
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.GITHUB_TOKEN }}
event-type: i18n-sync
client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}'

1
.nvmrc Normal file
View File

@ -0,0 +1 @@
22.11.0

View File

@ -60,10 +60,9 @@ check:
@echo "✅ Code check complete"
lint:
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
@echo "🔧 Running ruff format, check with fixes, and import linter..."
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
type-check:
@ -123,7 +122,7 @@ help:
@echo "Backend Code Quality:"
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make lint - Format and fix code with ruff"
@echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests"
@echo ""

View File

@ -575,10 +575,6 @@ LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
# otherwise write an empty {} instead. Defaults to writing the `graph` field.
LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@ -589,7 +585,6 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true

View File

@ -3,11 +3,9 @@ root_packages =
core
configs
controllers
extensions
models
tasks
services
include_external_packages = True
[importlinter:contract:workflow]
name = Workflow
@ -35,28 +33,6 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
core.workflow
forbidden_modules =
extensions.ext_database
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
[importlinter:contract:rsc]
name = RSC
type = layers

View File

@ -50,33 +50,16 @@ WORKDIR /app/api
# Create non-root user
ARG dify_uid=1001
ARG NODE_MAJOR=22
ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1
ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4
RUN groupadd -r -g ${dify_uid} dify && \
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
chown -R dify:dify /app
RUN \
apt-get update \
&& apt-get install -y --no-install-recommends \
ca-certificates \
curl \
gnupg \
&& mkdir -p /etc/apt/keyrings \
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \
&& gpg --show-keys --with-colons /tmp/nodesource.gpg \
| awk -F: '/^fpr:/ {print $10}' \
| grep -Fx "${NODESOURCE_KEY_FPR}" \
&& gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \
&& rm -f /tmp/nodesource.gpg \
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \
> /etc/apt/sources.list.d/nodesource.list \
&& apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
nodejs=${NODE_PACKAGE_VERSION} \
curl nodejs \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
@ -96,8 +79,7 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN mkdir -p /usr/local/share/nltk_data \
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache

View File

@ -1,5 +1,4 @@
import base64
import datetime
import json
import logging
import secrets
@ -35,7 +34,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
@ -46,7 +45,6 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@ -64,10 +62,8 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@ -88,7 +84,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(normalized_email)
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))
@ -104,22 +100,20 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
email_validate(normalized_new_email)
email_validate(new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = normalized_new_email
account.email = new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@ -241,7 +235,7 @@ def migrate_annotation_vector_database():
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question_text,
page_content=annotation.question,
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
)
documents.append(document)
@ -664,7 +658,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return
# Create account
email = email.strip().lower()
email = email.strip()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
@ -858,61 +852,6 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
)
@click.option(
"--dry-run",
is_flag=True,
help="Preview cleanup results without deleting any workflow run data.",
)
def clean_workflow_runs(
days: int,
batch_size: int,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
dry_run: bool,
):
"""
Clean workflow runs and related workflow data for free tenants.
"""
if (start_from is None) ^ (end_before is None):
raise click.UsageError("--start-from and --end-before must be provided together.")
start_time = datetime.datetime.now(datetime.UTC)
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
WorkflowRunCleanup(
days=days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
).run()
end_time = datetime.datetime.now(datetime.UTC)
elapsed = end_time - start_time
click.echo(
click.style(
f"Workflow run cleanup completed. start={start_time.isoformat()} "
f"end={end_time.isoformat()} duration={elapsed}",
fg="green",
)
)
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records(force: bool):
@ -1245,217 +1184,6 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
@click.command("file-usage", help="Query file usages and show where files are referenced.")
@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
@click.option("--key", type=str, default=None, help="Filter by storage key.")
@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
def file_usage(
file_id: str | None,
key: str | None,
src: str | None,
limit: int,
offset: int,
output_json: bool,
):
"""
Query file usages and show where files are referenced in the database.
This command reuses the same reference checking logic as clear-orphaned-file-records
and displays detailed information about where each file is referenced.
"""
# define tables and columns to process
files_tables = [
{"table": "upload_files", "id_column": "id", "key_column": "key"},
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
]
ids_tables = [
{"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
{"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
{"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
{"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
{"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
{"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
{"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
{"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
{"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
{"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
{"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
]
# Stream file usages with pagination to avoid holding all results in memory
paginated_usages = []
total_count = 0
# First, build a mapping of file_id -> storage_key from the base tables
file_key_map = {}
for files_table in files_tables:
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
# If filtering by key or file_id, verify it exists
if file_id and file_id not in file_key_map:
if output_json:
click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
else:
click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
return
if key:
valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
if not matching_file_ids:
if output_json:
click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
else:
click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
return
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
# For each reference table/column, find matching file IDs and record the references
for ids_table in ids_tables:
src_filter = f"{ids_table['table']}.{ids_table['column']}"
# Skip if src filter doesn't match (use fnmatch for wildcard patterns)
if src:
if "%" in src or "_" in src:
import fnmatch
# Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
pattern = src.replace("%", "*").replace("_", "?")
if not fnmatch.fnmatch(src_filter, pattern):
continue
else:
if src_filter != src:
continue
if ids_table["type"] == "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
elif ids_table["type"] in ("text", "json"):
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
# Output results
if output_json:
result = {
"total": total_count,
"offset": offset,
"limit": limit,
"usages": paginated_usages,
}
click.echo(json.dumps(result, indent=2))
else:
click.echo(
click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
)
click.echo("")
if not paginated_usages:
click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
return
# Print table header
click.echo(
click.style(
f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
fg="cyan",
)
)
click.echo(click.style("-" * 190, fg="white"))
# Print each usage
for usage in paginated_usages:
click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
# Show pagination info
if offset + limit < total_count:
click.echo("")
click.echo(
click.style(
f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
)
)
click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")

View File

@ -959,16 +959,6 @@ class MailConfig(BaseSettings):
default=None,
)
ENABLE_TRIAL_APP: bool = Field(
description="Enable trial app",
default=False,
)
ENABLE_EXPLORE_BANNER: bool = Field(
description="Enable explore banner",
default=False,
)
class RagEtlConfig(BaseSettings):
"""
@ -1111,10 +1101,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable clean messages task",
default=False,
)
ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
description="Enable scheduled workflow run cleanup task",
default=False,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,

View File

@ -8,11 +8,6 @@ class HostedCreditConfig(BaseSettings):
default="",
)
HOSTED_POOL_CREDITS: int = Field(
description="Pool credits for hosted service",
default=200,
)
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
@ -65,46 +60,19 @@ class HostedOpenAiConfig(BaseSettings):
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-4-turbo,"
"gpt-4.1,"
"gpt-4.1-2025-04-14,"
"gpt-4.1-mini,"
"gpt-4.1-mini-2025-04-14,"
"gpt-4.1-nano,"
"gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003,"
"chatgpt-4o-latest,"
"gpt-4o,"
"gpt-4o-2024-05-13,"
"gpt-4o-2024-08-06,"
"gpt-4o-2024-11-20,"
"gpt-4o-audio-preview,"
"gpt-4o-audio-preview-2025-06-03,"
"gpt-4o-mini,"
"gpt-4o-mini-2024-07-18,"
"o3-mini,"
"o3-mini-2025-01-31,"
"gpt-5-mini-2025-08-07,"
"gpt-5-mini,"
"o4-mini,"
"o4-mini-2025-04-16,"
"gpt-5-chat-latest,"
"gpt-5,"
"gpt-5-2025-08-07,"
"gpt-5-nano,"
"gpt-5-nano-2025-08-07",
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
@ -119,13 +87,6 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-4-turbo,"
"gpt-4.1,"
"gpt-4.1-2025-04-14,"
"gpt-4.1-mini,"
"gpt-4.1-mini-2025-04-14,"
"gpt-4.1-nano,"
"gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
@ -133,150 +94,7 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003,"
"chatgpt-4o-latest,"
"gpt-4o,"
"gpt-4o-2024-05-13,"
"gpt-4o-2024-08-06,"
"gpt-4o-2024-11-20,"
"gpt-4o-audio-preview,"
"gpt-4o-audio-preview-2025-06-03,"
"gpt-4o-mini,"
"gpt-4o-mini-2024-07-18,"
"o3-mini,"
"o3-mini-2025-01-31,"
"gpt-5-mini-2025-08-07,"
"gpt-5-mini,"
"o4-mini,"
"o4-mini-2025-04-16,"
"gpt-5-chat-latest,"
"gpt-5,"
"gpt-5-2025-08-07,"
"gpt-5-nano,"
"gpt-5-nano-2025-08-07",
)
class HostedGeminiConfig(BaseSettings):
"""
Configuration for fetching Gemini service
"""
HOSTED_GEMINI_API_KEY: str | None = Field(
description="API key for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_API_BASE: str | None = Field(
description="Base URL for hosted Gemini API",
default=None,
)
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Gemini service",
default=False,
)
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted gemini service",
default=False,
)
HOSTED_GEMINI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
class HostedXAIConfig(BaseSettings):
"""
Configuration for fetching XAI service
"""
HOSTED_XAI_API_KEY: str | None = Field(
description="API key for hosted XAI service",
default=None,
)
HOSTED_XAI_API_BASE: str | None = Field(
description="Base URL for hosted XAI API",
default=None,
)
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted XAI service",
default=None,
)
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted XAI service",
default=False,
)
HOSTED_XAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
HOSTED_XAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_XAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedDeepseekConfig(BaseSettings):
"""
Configuration for fetching Deepseek service
"""
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
description="API key for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
description="Base URL for hosted Deepseek API",
default=None,
)
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="deepseek-chat,deepseek-reasoner",
)
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="deepseek-chat,deepseek-reasoner",
"text-davinci-003",
)
@ -326,66 +144,16 @@ class HostedAnthropicConfig(BaseSettings):
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
class HostedTongyiConfig(BaseSettings):
"""
Configuration for hosted Tongyi service
"""
HOSTED_TONGYI_API_KEY: str | None = Field(
description="API key for hosted Tongyi service",
default=None,
)
HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT: bool = Field(
description="Use international endpoint for hosted Tongyi service",
default=False,
)
HOSTED_TONGYI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Tongyi service",
default=False,
)
HOSTED_TONGYI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_TONGYI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="",
)
HOSTED_TONGYI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="",
)
class HostedMinmaxConfig(BaseSettings):
"""
@ -478,13 +246,9 @@ class HostedServiceConfig(
HostedOpenAiConfig,
HostedSparkConfig,
HostedZhipuAIConfig,
HostedTongyiConfig,
# moderation
HostedModerationConfig,
# credit config
HostedCreditConfig,
HostedGeminiConfig,
HostedXAIConfig,
HostedDeepseekConfig,
):
pass

View File

@ -16,6 +16,7 @@ class MilvusConfig(BaseSettings):
description="Authentication token for Milvus, if token-based authentication is enabled",
default=None,
)
MILVUS_USER: str | None = Field(
description="Username for authenticating with Milvus, if username/password authentication is enabled",
default=None,

View File

@ -107,12 +107,10 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
banner,
installed_app,
parameter,
recommended_app,
saved_message,
trial,
)
# Import tag controllers
@ -147,7 +145,6 @@ __all__ = [
"apikey",
"app",
"audio",
"banner",
"billing",
"bp",
"completion",
@ -201,7 +198,6 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trial",
"trigger_providers",
"version",
"website",

View File

@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from models.model import App, InstalledApp, RecommendedApp
P = ParamSpec("P")
R = TypeVar("R")
@ -32,8 +32,6 @@ class InsertExploreAppPayload(BaseModel):
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
can_trial: bool = Field(default=False)
trial_limit: int = Field(default=0)
@field_validator("language")
@classmethod
@ -41,33 +39,11 @@ class InsertExploreAppPayload(BaseModel):
return supported_language(value)
class InsertExploreBannerPayload(BaseModel):
category: str = Field(...)
title: str = Field(...)
description: str = Field(...)
img_src: str = Field(..., alias="img-src")
language: str = Field(default="en-US")
link: str = Field(...)
sort: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
model_config = {"populate_by_name": True}
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
InsertExploreBannerPayload.__name__,
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]):
@wraps(view)
@ -133,20 +109,6 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@ -161,20 +123,6 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = payload.category
recommended_app.position = payload.position
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@ -220,62 +168,7 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
trial_app = session.execute(
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
).scalar_one_or_none()
if trial_app:
session.delete(trial_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/admin/insert-explore-banner")
class InsertExploreBannerApi(Resource):
@console_ns.doc("insert_explore_banner")
@console_ns.doc(description="Insert an explore banner")
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
@console_ns.response(201, "Banner inserted successfully")
@only_edition_cloud
@admin_required
def post(self):
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
content = {
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
}
banner = ExporleBanner(
content=content,
link=payload.link,
sort=payload.sort,
language=payload.language,
)
db.session.add(banner)
db.session.commit()
return {"result": "success"}, 201
@console_ns.route("/admin/insert-explore-banner/<uuid:banner_id>")
class DeleteExploreBannerApi(Resource):
@console_ns.doc("delete_explore_banner")
@console_ns.doc(description="Delete an explore banner")
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
@console_ns.response(204, "Banner deleted successfully")
@only_edition_cloud
@admin_required
def delete(self, banner_id):
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
if not banner:
raise NotFound(f"Banner '{banner_id}' is not found")
db.session.delete(banner)
db.session.commit()
return {"result": "success"}, 204

View File

@ -272,7 +272,6 @@ class AnnotationExportApi(Resource):
@account_initialization_required
@edit_permission_required
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response_data = {"data": marshal(annotation_list, annotation_fields)}
@ -360,6 +359,7 @@ class AnnotationBatchImportApi(Resource):
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
abort(

View File

@ -1,16 +1,14 @@
import re
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
from typing import Literal
from flask import request
from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@ -21,19 +19,27 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.file import helpers as file_helpers
from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType
from extensions.ext_database import db
from fields.app_fields import (
deleted_tool_fields,
model_config_fields,
model_config_partial_fields,
site_fields,
tag_fields,
)
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppListQuery(BaseModel):
@ -186,292 +192,124 @@ class AppTracePayload(BaseModel):
return value
JSONValue: TypeAlias = Any
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
reg(AppListQuery)
reg(CreateAppPayload)
reg(UpdateAppPayload)
reg(CopyAppPayload)
reg(AppExportQuery)
reg(AppNamePayload)
reg(AppIconPayload)
reg(AppSiteStatusPayload)
reg(AppApiStatusPayload)
reg(AppTracePayload)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base models first
tag_model = console_ns.model("Tag", tag_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
model_config_model = console_ns.model("ModelConfig", model_config_fields)
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE.value:
return None
return file_helpers.get_signed_file_url(icon)
model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
class Tag(ResponseModel):
id: str
name: str
type: str
site_model = console_ns.model("Site", site_fields)
app_partial_model = console_ns.model(
"AppPartial",
{
"id": fields.String,
"name": fields.String,
"max_active_requests": fields.Raw(),
"description": fields.String(attribute="desc_or_prompt"),
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"tags": fields.List(fields.Nested(tag_model)),
"access_mode": fields.String,
"create_user_name": fields.String,
"author_name": fields.String,
"has_draft_trigger": fields.Boolean,
},
)
class WorkflowPartial(ResponseModel):
id: str
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
app_detail_model = console_ns.model(
"AppDetail",
{
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"tracing": fields.Raw,
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_model)),
},
)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
app_detail_with_site_model = console_ns.model(
"AppDetailWithSite",
{
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
"max_active_requests": fields.Integer,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_model)),
"site": fields.Nested(site_model),
},
)
class ModelConfigPartial(ResponseModel):
model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
pre_prompt: str | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class ModelConfig(ResponseModel):
opening_statement: str | None = None
suggested_questions: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions")
)
suggested_questions_after_answer: JSONValue | None = Field(
default=None,
validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"),
)
speech_to_text: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text")
)
text_to_speech: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech")
)
retriever_resource: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource")
)
annotation_reply: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply")
)
more_like_this: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this")
)
sensitive_word_avoidance: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance")
)
external_data_tools: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools")
)
model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
user_input_form: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form")
)
dataset_query_variable: str | None = None
pre_prompt: str | None = None
agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode"))
prompt_type: str | None = None
chat_prompt_config: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config")
)
completion_prompt_config: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config")
)
dataset_configs: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs")
)
file_upload: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload")
)
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class Site(ResponseModel):
access_token: str | None = Field(default=None, validation_alias="code")
code: str | None = None
title: str | None = None
icon_type: str | IconType | None = None
icon: str | None = None
icon_background: str | None = None
description: str | None = None
default_language: str | None = None
chat_color_theme: str | None = None
chat_color_theme_inverted: bool | None = None
customize_domain: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
customize_token_strategy: str | None = None
prompt_public: bool | None = None
app_base_url: str | None = None
show_workflow_steps: bool | None = None
use_icon_as_answer_icon: bool | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
@field_validator("icon_type", mode="before")
@classmethod
def _normalize_icon_type(cls, value: str | IconType | None) -> str | None:
if isinstance(value, IconType):
return value.value
return value
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class DeletedTool(ResponseModel):
type: str
tool_name: str
provider_id: str
class AppPartial(ResponseModel):
id: str
name: str
max_active_requests: int | None = None
description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description"))
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
model_config_: ModelConfigPartial | None = Field(
default=None,
validation_alias=AliasChoices("app_model_config", "model_config"),
alias="model_config",
)
workflow: WorkflowPartial | None = None
use_icon_as_answer_icon: bool | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
tags: list[Tag] = Field(default_factory=list)
access_mode: str | None = None
create_user_name: str | None = None
author_name: str | None = None
has_draft_trigger: bool | None = None
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AppDetail(ResponseModel):
id: str
name: str
description: str | None = None
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon: str | None = None
icon_background: str | None = None
enable_site: bool
enable_api: bool
model_config_: ModelConfig | None = Field(
default=None,
validation_alias=AliasChoices("app_model_config", "model_config"),
alias="model_config",
)
workflow: WorkflowPartial | None = None
tracing: JSONValue | None = None
use_icon_as_answer_icon: bool | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
access_mode: str | None = None
tags: list[Tag] = Field(default_factory=list)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AppDetailWithSite(AppDetail):
icon_type: str | None = None
api_base_url: str | None = None
max_active_requests: int | None = None
deleted_tools: list[DeletedTool] = Field(default_factory=list)
site: Site | None = None
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
class AppPagination(ResponseModel):
page: int
limit: int = Field(validation_alias=AliasChoices("per_page", "limit"))
total: int
has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more"))
data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data"))
class AppExportResponse(ResponseModel):
data: str
register_schema_models(
console_ns,
AppListQuery,
CreateAppPayload,
UpdateAppPayload,
CopyAppPayload,
AppExportQuery,
AppNamePayload,
AppIconPayload,
AppSiteStatusPayload,
AppApiStatusPayload,
AppTracePayload,
Tag,
WorkflowPartial,
ModelConfigPartial,
ModelConfig,
Site,
DeletedTool,
AppPartial,
AppDetail,
AppDetailWithSite,
AppPagination,
AppExportResponse,
app_pagination_model = console_ns.model(
"AppPagination",
{
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(app_partial_model), attribute="items"),
},
)
@ -480,7 +318,7 @@ class AppListApi(Resource):
@console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(console_ns.models[AppListQuery.__name__])
@console_ns.response(200, "Success", console_ns.models[AppPagination.__name__])
@console_ns.response(200, "Success", app_pagination_model)
@setup_required
@login_required
@account_initialization_required
@ -496,8 +334,7 @@ class AppListApi(Resource):
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
if FeatureService.get_system_features().webapp_auth.enabled:
app_ids = [str(app.id) for app in app_pagination.items]
@ -541,18 +378,18 @@ class AppListApi(Resource):
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
return pagination_model.model_dump(mode="json"), 200
return marshal(app_pagination, app_pagination_model), 200
@console_ns.doc("create_app")
@console_ns.doc(description="Create a new application")
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
@console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(201, "App created successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@ -562,8 +399,8 @@ class AppListApi(Resource):
app_service = AppService()
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
app_detail = AppDetail.model_validate(app, from_attributes=True)
return app_detail.model_dump(mode="json"), 201
return app, 201
@console_ns.route("/apps/<uuid:app_id>")
@ -571,12 +408,13 @@ class AppApi(Resource):
@console_ns.doc("get_app_detail")
@console_ns.doc(description="Get application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(200, "Success", app_detail_with_site_model)
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@get_app_model(mode=None)
@get_app_model
@marshal_with(app_detail_with_site_model)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
@ -587,21 +425,21 @@ class AppApi(Resource):
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
return app_model
@console_ns.doc("update_app")
@console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
@console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@get_app_model
@edit_permission_required
@marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
args = UpdateAppPayload.model_validate(console_ns.payload)
@ -618,8 +456,8 @@ class AppApi(Resource):
"max_active_requests": args.max_active_requests or 0,
}
app_model = app_service.update_app(app_model, args_dict)
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
return app_model
@console_ns.doc("delete_app")
@console_ns.doc(description="Delete application")
@ -645,13 +483,14 @@ class AppCopyApi(Resource):
@console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
@console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@get_app_model
@edit_permission_required
@marshal_with(app_detail_with_site_model)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
@ -677,8 +516,7 @@ class AppCopyApi(Resource):
stmt = select(App).where(App.id == result.app_id)
app = session.scalar(stmt)
response_model = AppDetailWithSite.model_validate(app, from_attributes=True)
return response_model.model_dump(mode="json"), 201
return app, 201
@console_ns.route("/apps/<uuid:app_id>/export")
@ -687,7 +525,11 @@ class AppExportApi(Resource):
@console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
@console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__])
@console_ns.response(
200,
"App exported successfully",
console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
)
@console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@ -698,14 +540,13 @@ class AppExportApi(Resource):
"""Export app"""
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
payload = AppExportResponse(
data=AppDslService.export_dsl(
return {
"data": AppDslService.export_dsl(
app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
)
)
return payload.model_dump(mode="json")
}
@console_ns.route("/apps/<uuid:app_id>/name")
@ -714,19 +555,20 @@ class AppNameApi(Resource):
@console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
@console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__])
@console_ns.response(200, "Name availability checked")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@get_app_model
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = AppNamePayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_name(app_model, args.name)
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
return app_model
@console_ns.route("/apps/<uuid:app_id>/icon")
@ -740,15 +582,16 @@ class AppIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@get_app_model
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
return app_model
@console_ns.route("/apps/<uuid:app_id>/site-enable")
@ -757,20 +600,21 @@ class AppSiteStatus(Resource):
@console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
@console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(200, "Site status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@get_app_model
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = AppSiteStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.enable_site)
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
return app_model
@console_ns.route("/apps/<uuid:app_id>/api-enable")
@ -779,20 +623,21 @@ class AppApiStatus(Resource):
@console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
@console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(200, "API status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@get_app_model(mode=None)
@get_app_model
@marshal_with(app_detail_model)
def post(self, app_model):
args = AppApiStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.enable_api)
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
return app_model
@console_ns.route("/apps/<uuid:app_id>/trace")

View File

@ -348,13 +348,10 @@ class CompletionConversationApi(Resource):
)
if args.keyword:
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(args.keyword)
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
)
)
@ -463,10 +460,7 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args.keyword:
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(args.keyword)
keyword_filter = f"%{escaped_keyword}%"
keyword_filter = f"%{args.keyword}%"
query = (
query.join(
Message,
@ -475,11 +469,11 @@ class ChatConversationApi(Resource):
.join(subquery, subquery.c.conversation_id == Conversation.id)
.where(
or_(
Message.query.ilike(keyword_filter, escape="\\"),
Message.answer.ilike(keyword_filter, escape="\\"),
Conversation.name.ilike(keyword_filter, escape="\\"),
Conversation.introduction.ilike(keyword_filter, escape="\\"),
subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
),
)
.group_by(Conversation.id)

View File

@ -115,9 +115,3 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
class NeedAddIdsError(BaseHTTPException):
error_code = "need_add_ids"
description = "Need to add ids."
code = 400

View File

@ -202,7 +202,6 @@ message_detail_model = console_ns.model(
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
"generation_detail": fields.Raw,
},
)

View File

@ -23,11 +23,6 @@ def _load_app_model(app_id: str) -> App | None:
return app_model
def _load_app_model_with_trial(app_id: str) -> App | None:
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
@ -67,44 +62,3 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator
else:
return decorator(view)
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
app_id = kwargs.get("app_id")
app_id = str(app_id)
del kwargs["app_id"]
app_model = _load_app_model_with_trial(app_id)
if not app_model:
raise AppNotFoundError()
app_mode = AppMode.value_of(app_model.mode)
if mode is not None:
if isinstance(mode, list):
modes = mode
else:
modes = [mode]
if app_mode not in modes:
mode_values = {m.value for m in modes}
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
kwargs["app_model"] = app_model
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@ -63,9 +63,10 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id
reg_email = args.email
token = args.token
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
@ -99,12 +100,11 @@ class ActivateApi(Resource):
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
normalized_request_email = args.email.lower() if args.email else None
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
account = invitation["account"]
account.name = args.name

View File

@ -1,6 +1,7 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@ -61,7 +62,6 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -70,12 +70,13 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages:
language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
return {"result": "success", "data": token}
@ -87,9 +88,9 @@ class EmailRegisterCheckApi(Resource):
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
user_email = args.email.lower()
user_email = args.email
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
@ -97,14 +98,11 @@ class EmailRegisterCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(user_email)
AccountService.add_email_register_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
@ -115,8 +113,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/email-register")
@ -143,23 +141,22 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
normalized_email = email.lower()
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
account = self._create_new_account(email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
AccountService.reset_login_error_rate_limit(email)
return {"result": "success", "data": token_pair.model_dump()}
def _create_new_account(self, email: str, password: str) -> Account | None:
def _create_new_account(self, email, password) -> Account | None:
# Create new account if allowed
account = None
try:

View File

@ -4,6 +4,7 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -20,6 +21,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@ -74,7 +76,6 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -86,11 +87,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = AccountService.send_reset_password_email(
account=account,
email=normalized_email,
email=args.email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@ -121,9 +122,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
user_email = args.email.lower()
user_email = args.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@ -131,16 +132,11 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(user_email)
AccountService.add_forgot_password_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
@ -148,11 +144,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
token_email, code=args.code, additional_data={"phase": "reset"}
user_email, code=args.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/forgot-password/resets")
@ -191,8 +187,9 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
self._update_existing_account(account, password_hashed, salt, session)

View File

@ -1,5 +1,3 @@
from typing import Any
import flask_login
from flask import make_response, request
from flask_restx import Resource
@ -90,38 +88,33 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
request_email = args.email
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:
invite_token = None
# TODO: why invitation is re-assigned with different type?
invitation = args.invite_token # type: ignore
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
try:
if invitation_data:
data = invitation_data.get("data", {})
if invitation:
data = invitation.get("data", {}) # type: ignore
invitee_email = data.get("email") if data else None
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
if invitee_email != args.email:
raise InvalidEmailError()
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
account = AccountService.authenticate(args.email, args.password, args.invite_token)
else:
account = AccountService.authenticate(args.email, args.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
raise AuthenticationFailedError() from exc
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@ -136,7 +129,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@ -176,19 +169,18 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = _get_account_with_case_fallback(args.email)
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
email=normalized_email,
email=args.email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@ -203,7 +195,6 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -214,13 +205,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else:
language = "en-US"
try:
account = _get_account_with_case_fallback(args.email)
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
token = AccountService.send_email_code_login_email(email=args.email, language=language)
else:
raise AccountNotFound()
else:
@ -237,17 +228,14 @@ class EmailCodeLoginApi(Resource):
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
original_email = args.email
user_email = original_email.lower()
user_email = args.email
language = args.language
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
if token_data["email"] != args.email:
raise InvalidEmailError()
if token_data["code"] != args.code:
@ -255,7 +243,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token)
try:
account = _get_account_with_case_fallback(original_email)
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError:
raise AccountInFreezeError()
if account:
@ -286,7 +274,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(user_email)
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@ -320,22 +308,3 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
def _get_account_with_case_fallback(email: str):
account = AccountService.get_user_through_email(email)
if account or email == email.lower():
return account
return AccountService.get_user_through_email(email.lower())
def _authenticate_account_with_case_fallback(
original_email: str, normalized_email: str, password: str, invite_token: str | None
):
try:
return AccountService.authenticate(original_email, password, invite_token)
except services.errors.account.AccountPasswordError:
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)

View File

@ -3,6 +3,7 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@ -117,10 +118,7 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
invitation_email_normalized = (
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
)
if invitation_email_normalized != user_info.email.lower():
if invitation_email != user_info.email:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@ -161,7 +159,10 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
response = redirect(f"{dify_config.CONSOLE_WEB_URL}?oauth_new_user={str(oauth_new_user).lower()}")
base_url = dify_config.CONSOLE_WEB_URL
query_char = "&" if "?" in base_url else "?"
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
response = redirect(target_url)
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@ -174,7 +175,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
return account
@ -196,10 +197,9 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
tenant_was_created.send(new_tenant)
if not account:
normalized_email = user_info.email.lower()
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@ -210,11 +210,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)
# Set interface language

View File

@ -146,7 +146,6 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
summary_index_setting: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None

View File

@ -39,10 +39,9 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog, DocumentSegmentSummary
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from tasks.generate_summary_index_task import generate_summary_index_task
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@ -105,10 +104,6 @@ class DocumentRenamePayload(BaseModel):
name: str
class GenerateSummaryPayload(BaseModel):
document_list: list[str]
register_schema_models(
console_ns,
KnowledgeConfig,
@ -116,7 +111,6 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
)
@ -301,97 +295,6 @@ class DatasetDocumentListApi(Resource):
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
# Check if dataset has summary index enabled
has_summary_index = (
dataset.summary_index_setting
and dataset.summary_index_setting.get("enable") is True
)
# Filter documents that need summary calculation
documents_need_summary = [doc for doc in documents if doc.need_summary is True]
document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
# Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
summary_status_map = {}
if has_summary_index and document_ids_need_summary:
# Get all segments for these documents (excluding qa_model and re_segment)
segments = (
db.session.query(DocumentSegment.id, DocumentSegment.document_id)
.where(
DocumentSegment.document_id.in_(document_ids_need_summary),
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == current_tenant_id,
)
.all()
)
# Group segments by document_id
document_segments_map = {}
for segment in segments:
doc_id = str(segment.document_id)
if doc_id not in document_segments_map:
document_segments_map[doc_id] = []
document_segments_map[doc_id].append(segment.id)
# Get all summary records for these segments
all_segment_ids = [seg.id for seg in segments]
summaries = {}
if all_segment_ids:
summary_records = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id.in_(all_segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only count enabled summaries
)
.all()
)
summaries = {summary.chunk_id: summary.status for summary in summary_records}
# Calculate summary_index_status for each document
for doc_id in document_ids_need_summary:
segment_ids = document_segments_map.get(doc_id, [])
if not segment_ids:
# No segments, status is "GENERATING" (waiting to generate)
summary_status_map[doc_id] = "GENERATING"
continue
# Count summary statuses for this document's segments
status_counts = {"completed": 0, "generating": 0, "error": 0, "not_started": 0}
for segment_id in segment_ids:
status = summaries.get(segment_id, "not_started")
if status in status_counts:
status_counts[status] += 1
else:
status_counts["not_started"] += 1
total_segments = len(segment_ids)
completed_count = status_counts["completed"]
generating_count = status_counts["generating"]
error_count = status_counts["error"]
# Determine overall status (only three states: GENERATING, COMPLETED, ERROR)
if completed_count == total_segments:
summary_status_map[doc_id] = "COMPLETED"
elif error_count > 0:
# Has errors (even if some are completed or generating)
summary_status_map[doc_id] = "ERROR"
elif generating_count > 0 or status_counts["not_started"] > 0:
# Still generating or not started
summary_status_map[doc_id] = "GENERATING"
else:
# Default to generating
summary_status_map[doc_id] = "GENERATING"
# Add summary_index_status to each document
for document in documents:
if has_summary_index and document.need_summary is True:
document.summary_index_status = summary_status_map.get(str(document.id), "GENERATING")
else:
# Return null if summary index is not enabled or document doesn't need summary
document.summary_index_status = None
if fetch:
for document in documents:
completed_segments = (
@ -490,7 +393,6 @@ class DatasetDocumentListApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/datasets/init")
class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@ -849,12 +751,12 @@ class DocumentApi(DocumentResource):
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": document.data_source_info_dict,
"data_source_detail_dict": document.data_source_detail_dict,
"data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
@ -878,17 +780,16 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": document.data_source_info_dict,
"data_source_detail_dict": document.data_source_detail_dict,
"data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
@ -914,7 +815,6 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
return response, 200
@ -1282,211 +1182,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
class DocumentGenerateSummaryApi(Resource):
@console_ns.doc("generate_summary_for_documents")
@console_ns.doc(description="Generate summary index for documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
@console_ns.response(200, "Summary generation started successfully")
@console_ns.response(400, "Invalid request or dataset configuration")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""
Generate summary index for specified documents.
This endpoint checks if the dataset configuration supports summary generation
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Validate request payload
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
document_list = payload.document_list
if not document_list:
raise ValueError("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"
)
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError(
"Summary index is not enabled for this dataset. "
"Please enable it in the dataset settings."
)
# Verify all documents exist and belong to the dataset
documents = (
db.session.query(Document)
.filter(
Document.id.in_(document_list),
Document.dataset_id == dataset_id,
)
.all()
)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries
if document.doc_form == "qa_model":
logger.info(
f"Skipping summary generation for qa_model document {document.id}"
)
continue
# Dispatch async task
generate_summary_index_task(dataset_id, document.id)
logger.info(
f"Dispatched summary generation task for document {document.id} in dataset {dataset_id}"
)
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
class DocumentSummaryStatusApi(DocumentResource):
@console_ns.doc("get_document_summary_status")
@console_ns.doc(description="Get summary index generation status for a document")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Summary status retrieved successfully")
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
Returns:
- total_segments: Total number of segments in the document
- summary_status: Dictionary with status counts
- completed: Number of summaries completed
- generating: Number of summaries being generated
- error: Number of summaries with errors
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get document
document = self.get_document(dataset_id, document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Get all segments for this document
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
total_segments = len(segments)
# Get all summary records for these segments
segment_ids = [segment.id for segment in segments]
summaries = []
if segment_ids:
summaries = (
db.session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.document_id == document_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
)
.all()
)
# Create a mapping of chunk_id to summary
summary_map = {summary.chunk_id: summary for summary in summaries}
# Count statuses
status_counts = {
"completed": 0,
"generating": 0,
"error": 0,
"not_started": 0,
}
summary_list = []
for segment in segments:
summary = summary_map.get(segment.id)
if summary:
status = summary.status
status_counts[status] = status_counts.get(status, 0) + 1
summary_list.append({
"segment_id": segment.id,
"segment_position": segment.position,
"status": summary.status,
"summary_preview": summary.summary_content[:100] + "..." if summary.summary_content and len(summary.summary_content) > 100 else summary.summary_content,
"error": summary.error,
"created_at": int(summary.created_at.timestamp()) if summary.created_at else None,
"updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None,
})
else:
status_counts["not_started"] += 1
summary_list.append({
"segment_id": segment.id,
"segment_position": segment.position,
"status": "not_started",
"summary_preview": None,
"error": None,
"created_at": None,
"updated_at": None,
})
return {
"total_segments": total_segments,
"summary_status": status_counts,
"summaries": summary_list,
}, 200

View File

@ -30,9 +30,8 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment, DocumentSegmentSummary
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
@ -41,23 +40,6 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
segment_dict = marshal(segment, segment_fields)
# Query summary for this segment (only enabled summaries)
summary = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
)
.first()
)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@ -80,7 +62,6 @@ class SegmentUpdatePayload(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class BatchImportPayload(BaseModel):
@ -164,8 +145,6 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
escaped_keyword = escape_like_pattern(keyword)
# Search in both content and keywords fields
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
@ -177,15 +156,15 @@ class DatasetDocumentSegmentListApi(Resource):
.scalar_subquery()
),
",",
).ilike(f"%{escaped_keyword}%", escape="\\")
).ilike(f"%{keyword}%")
else:
# MySQL: Cast JSON to string for pattern matching
# MySQL stores Chinese text directly in JSON without Unicode escaping
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
query = query.where(
or_(
DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
DocumentSegment.content.ilike(f"%{keyword}%"),
keywords_condition,
)
)
@ -198,34 +177,8 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
if segment_ids:
summary_records = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
)
.all()
)
# Only include enabled summaries
summaries = {
summary.chunk_id: summary.summary_content
for summary in summary_records
if summary.enabled is True
}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = marshal(segment, segment_fields)
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": segments_with_summary,
"data": marshal(segments.items, segment_fields),
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
@ -371,7 +324,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@ -433,12 +386,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@login_required

View File

@ -1,4 +1,4 @@
from flask_restx import Resource, fields
from flask_restx import Resource
from controllers.common.schema import register_schema_model
from libs.login import login_required
@ -10,56 +10,17 @@ from ..wraps import (
cloud_edition_billing_rate_limit_check,
setup_required,
)
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
register_schema_model(console_ns, HitTestingPayload)
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required

View File

@ -1,7 +1,7 @@
import logging
from typing import Any
from flask_restx import marshal
from flask_restx import marshal, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -56,10 +56,15 @@ class DatasetsHitTestingBase:
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
"""Validate and return hit-testing arguments from an incoming payload."""
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
return hit_testing_payload.model_dump(exclude_none=True)
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()
@staticmethod
def perform_hit_testing(dataset, args):

View File

@ -355,7 +355,7 @@ class PublishedRagPipelineRunApi(Resource):
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED_PIPELINE,
invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
streaming=streaming,
)

View File

@ -1,43 +0,0 @@
from flask import request
from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.model import ExporleBanner
class BannerApi(Resource):
"""Resource for banner list."""
@explore_banner_enabled
def get(self):
"""Get banner list."""
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
# Fallback to en-US if no banners found and language is not en-US
if not banners and language != "en-US":
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
# Convert banners to serializable format
result = []
for banner in banners:
banner_data = {
"id": banner.id,
"content": banner.content, # Already parsed as JSON by SQLAlchemy
"link": banner.link,
"sort": banner.sort,
"status": banner.status,
"created_at": banner.created_at.isoformat() if banner.created_at else None,
}
result.append(banner_data)
return result
api.add_resource(BannerApi, "/explore/banners")

View File

@ -29,25 +29,3 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
class TrialAppNotAllowed(BaseHTTPException):
"""*403* `Trial App Not Allowed`
Raise if the user has reached the trial app limit.
"""
error_code = "trial_app_not_allowed"
code = 403
description = "the app is not allowed to be trial."
class TrialAppLimitExceeded(BaseHTTPException):
"""*403* `Trial App Limit Exceeded`
Raise if the user has exceeded the trial app limit.
"""
error_code = "trial_app_limit_exceeded"
code = 403
description = "The user has exceeded the trial app limit."

View File

@ -29,7 +29,6 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_list_fields = {

View File

@ -1,512 +0,0 @@
import logging
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
NeedAddIdsError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model_with_trial
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
NotWorkflowAppError,
)
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from fields.dataset_fields import dataset_fields
from fields.workflow_fields import workflow_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.account import TenantStatus
from models.model import AppMode, Site
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.app_service import AppService
from services.audio_service import AudioService
from services.dataset_service import DatasetService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
class TrialAppWorkflowRunApi(TrialAppResource):
def post(self, trial_app):
"""
Run workflow
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
assert current_user is not None
try:
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialAppWorkflowTaskStopApi(TrialAppResource):
def post(self, trial_app, task_id: str):
"""
Stop workflow task
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}
class TrialChatApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialMessageSuggestedQuestionApi(TrialAppResource):
@trial_feature_enable
def get(self, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
class TrialChatAudioApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
file = request.files["file"]
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialChatTextApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialCompletionApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app site info.
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
assert app_model.tenant
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return SiteResponse.model_validate(site).model_dump(mode="json")
class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app parameters."""
if app_model is None:
raise AppUnavailableError()
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return ParametersResponse.model_validate(parameters).model_dump(mode="json")
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
app_model = app_service.get_app(app_model)
return app_model
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(workflow_fields)
def get(self, app_model):
"""Get workflow detail"""
if not app_model.workflow_id:
raise AppUnavailableError()
workflow = (
db.session.query(Workflow)
.where(
Workflow.id == app_model.workflow_id,
)
.first()
)
return workflow
class DatasetListApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
tenant_id = app_model.tenant_id
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
else:
raise NeedAddIdsError()
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
api.add_resource(
TrialMessageSuggestedQuestionApi,
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="trial_app_suggested_question",
)
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")

View File

@ -2,15 +2,14 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from models import InstalledApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -72,61 +71,6 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
if trial_app is None:
raise TrialAppNotAllowed()
app = trial_app.app
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:
raise TrialAppLimitExceeded()
return view(app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_trial_app:
abort(403, "Trial app feature is not enabled.")
return view(*args, **kwargs)
return decorated
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_explore_banner:
abort(403, "Explore banner feature is not enabled.")
return view(*args, **kwargs)
return decorated
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@ -136,13 +80,3 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
class TrialAppResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [
trial_app_required,
account_initialization_required,
login_required,
]

View File

@ -1,7 +1,7 @@
from typing import Literal
from flask import request
from flask_restx import Resource
from flask_restx import Resource, marshal_with
from werkzeug.exceptions import Forbidden
import services
@ -15,21 +15,18 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
from fields.file_fields import FileResponse, UploadConfig
from fields.file_fields import file_fields, upload_config_fields
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService
from . import console_ns
register_schema_models(console_ns, UploadConfig, FileResponse)
PREVIEW_WORDS_LIMIT = 3000
@ -38,27 +35,26 @@ class FileApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[UploadConfig.__name__])
@marshal_with(upload_config_fields)
def get(self):
config = UploadConfig(
file_size_limit=dify_config.UPLOAD_FILE_SIZE_LIMIT,
batch_count_limit=dify_config.UPLOAD_FILE_BATCH_LIMIT,
file_upload_limit=dify_config.BATCH_UPLOAD_LIMIT,
image_file_size_limit=dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
video_file_size_limit=dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
audio_file_size_limit=dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
workflow_file_upload_limit=dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
image_file_batch_limit=dify_config.IMAGE_FILE_BATCH_LIMIT,
single_chunk_attachment_limit=dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
attachment_image_file_size_limit=dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
)
return config.model_dump(mode="json"), 200
return {
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
"file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200
@setup_required
@login_required
@account_initialization_required
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source")
@ -94,8 +90,7 @@ class FileApi(Resource):
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
raise BlockedFileExtensionError(blocked_extension_error.description)
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201
return upload_file, 201
@console_ns.route("/files/<uuid:file_id>/preview")

View File

@ -1,7 +1,7 @@
import urllib.parse
import httpx
from flask_restx import Resource
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
import services
@ -11,22 +11,19 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
@console_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(Resource):
@console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
@ -34,11 +31,10 @@ class RemoteFileInfoApi(Resource):
# failed back to get method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
return info.model_dump(mode="json")
return {
"file_type": resp.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(resp.headers.get("Content-Length", 0)),
}
class RemoteFileUploadPayload(BaseModel):
@ -54,7 +50,7 @@ console_ns.schema_model(
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
@marshal_with(file_fields_with_signed_url)
def post(self):
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
@ -89,14 +85,13 @@ class RemoteFileUploadApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
payload = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return payload.model_dump(mode="json"), 201
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201

View File

@ -84,11 +84,10 @@ class SetupApi(Resource):
raise NotInitValidateError()
args = SetupRequestPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
# setup
RegisterService.setup(
email=normalized_email,
email=args.email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),

View File

@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
@ -41,7 +39,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -101,7 +99,7 @@ class AccountPasswordPayload(BaseModel):
repeat_new_password: str
@model_validator(mode="after")
def check_passwords_match(self) -> AccountPasswordPayload:
def check_passwords_match(self) -> "AccountPasswordPayload":
if self.new_password != self.repeat_new_password:
raise RepeatPasswordNotMatchError()
return self
@ -536,8 +534,7 @@ class ChangeEmailSendEmailApi(Resource):
else:
language = "en-US"
account = None
user_email = None
email_for_sending = args.email.lower()
user_email = args.email
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
@ -547,24 +544,16 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
if user_email != current_user.email:
raise InvalidEmailError()
user_email = current_user.email
else:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
if account is None:
raise AccountNotFound()
email_for_sending = account.email
user_email = account.email
token = AccountService.send_change_email_email(
account=account,
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
)
return {"result": "success", "data": token}
@ -580,9 +569,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args.email.lower()
user_email = args.email
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
@ -590,13 +579,11 @@ class ChangeEmailCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(user_email)
AccountService.add_change_email_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
@ -607,8 +594,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
AccountService.reset_change_email_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/account/change-email/reset")
@ -622,12 +609,11 @@ class ChangeEmailResetApi(Resource):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
normalized_new_email = args.new_email.lower()
if AccountService.is_account_in_freeze(normalized_new_email):
if AccountService.is_account_in_freeze(args.new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(normalized_new_email):
if not AccountService.check_email_unique(args.new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@ -638,13 +624,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
if current_user.email != old_email:
raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
AccountService.send_change_email_completed_notify_email(
email=normalized_new_email,
email=args.new_email,
)
return updated_account
@ -657,9 +643,8 @@ class CheckEmailUnique(Resource):
def post(self):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
normalized_email = args.email.lower()
if AccountService.is_account_in_freeze(normalized_email):
if AccountService.is_account_in_freeze(args.email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(normalized_email):
if not AccountService.check_email_unique(args.email):
raise EmailAlreadyInUseError()
return {"result": "success"}

View File

@ -116,31 +116,26 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
)
encoded_invitee_email = parse.quote(normalized_invitee_email)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": normalized_invitee_email,
"email": invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
return {
"result": "success",

View File

@ -1,14 +1,15 @@
import logging
from collections.abc import Mapping
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource
from pydantic import BaseModel, model_validator
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.common.schema import register_schema_models
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@ -35,38 +36,29 @@ from ..wraps import (
logger = logging.getLogger(__name__)
class TriggerSubscriptionBuilderCreatePayload(BaseModel):
credential_type: str = CredentialType.UNAUTHORIZED
class TriggerSubscriptionUpdateRequest(BaseModel):
"""Request payload for updating a trigger subscription"""
name: str | None = Field(default=None, description="The name for the subscription")
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
class TriggerSubscriptionBuilderVerifyPayload(BaseModel):
credentials: dict[str, Any]
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
name: str | None = None
parameters: dict[str, Any] | None = None
properties: dict[str, Any] | None = None
credentials: dict[str, Any] | None = None
console_ns.schema_model(
TriggerSubscriptionUpdateRequest.__name__,
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
@model_validator(mode="after")
def check_at_least_one_field(self):
if all(v is None for v in self.model_dump().values()):
raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
return self
class TriggerOAuthClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enabled: bool | None = None
register_schema_models(
console_ns,
TriggerSubscriptionBuilderCreatePayload,
TriggerSubscriptionBuilderVerifyPayload,
TriggerSubscriptionBuilderUpdatePayload,
TriggerOAuthClientPayload,
console_ns.schema_model(
TriggerSubscriptionVerifyRequest.__name__,
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
@ -135,11 +127,16 @@ class TriggerSubscriptionListApi(Resource):
raise
parser = reqparse.RequestParser().add_argument(
"credential_type", type=str, required=False, nullable=True, location="json"
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__])
@console_ns.expect(parser)
@setup_required
@login_required
@edit_permission_required
@ -149,10 +146,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args()
try:
credential_type = CredentialType.of(payload.credential_type)
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
@ -180,11 +177,18 @@ class TriggerSubscriptionBuilderGetApi(Resource):
)
parser_api = (
reqparse.RequestParser()
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
@console_ns.expect(parser_api)
@setup_required
@login_required
@edit_permission_required
@ -194,7 +198,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
user = current_user
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
args = parser_api.parse_args()
try:
# Use atomic update_and_verify to prevent race conditions
@ -204,7 +208,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=payload.credentials,
credentials=args.get("credentials", None),
),
)
except Exception as e:
@ -212,11 +216,24 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
raise ValueError(str(e)) from e
parser_update_api = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@edit_permission_required
@ -227,7 +244,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
args = parser_update_api.parse_args()
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
@ -235,10 +252,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
credentials=payload.credentials,
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
),
)
)
@ -273,7 +290,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@edit_permission_required
@ -282,7 +299,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
args = parser_update_api.parse_args()
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
@ -291,9 +308,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
),
)
return 200
@ -306,7 +323,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
)
class TriggerSubscriptionUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@ -316,7 +333,7 @@ class TriggerSubscriptionUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
@ -328,32 +345,50 @@ class TriggerSubscriptionUpdateApi(Resource):
provider_id = TriggerProviderID(subscription.provider_id)
try:
# For rename only, just update the name
rename = request.name is not None and not any((request.credentials, request.parameters, request.properties))
# When credential type is UNAUTHORIZED, it indicates the subscription was manually created
# For Manually created subscription, they dont have credentials, parameters
# They only have name and properties(which is input by user)
manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
if rename or manually_created:
# rename only
if (
args.name is not None
and args.credentials is None
and args.parameters is None
and args.properties is None
):
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=request.name,
properties=request.properties,
name=args.name,
)
return 200
# For the rest cases(API_KEY, OAUTH2)
# we need to call third party provider(e.g. GitHub) to rebuild the subscription
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=request.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=request.credentials or subscription.credentials,
parameters=request.parameters or subscription.parameters,
)
return 200
# rebuild for create automatically by the provider
match subscription.credential_type:
case CredentialType.UNAUTHORIZED:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
properties=args.properties,
)
return 200
case CredentialType.API_KEY | CredentialType.OAUTH2:
if args.credentials:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in args.credentials.items()
}
else:
new_credentials = subscription.credentials
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=args.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=new_credentials,
parameters=args.parameters or subscription.parameters,
)
return 200
case _:
raise BadRequest("Invalid credential type")
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
@ -546,6 +581,13 @@ class TriggerOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_oauth_client = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@ -593,7 +635,7 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error getting OAuth client", exc_info=e)
raise
@console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__])
@console_ns.expect(parser_oauth_client)
@setup_required
@login_required
@is_admin_or_owner_required
@ -603,15 +645,15 @@ class TriggerOAuthClientManageApi(Resource):
user = current_user
assert user.current_tenant_id is not None
payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
args = parser_oauth_client.parse_args()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=payload.client_params,
enabled=payload.enabled,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
)
except ValueError as e:
@ -647,7 +689,7 @@ class TriggerOAuthClientManageApi(Resource):
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
)
class TriggerSubscriptionVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@ -657,7 +699,9 @@ class TriggerSubscriptionVerifyApi(Resource):
user = current_user
assert user.current_tenant_id is not None
verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
console_ns.payload
)
try:
result = TriggerProviderService.verify_subscription_credentials(

View File

@ -80,9 +80,6 @@ tenant_fields = {
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
"trial_credits": fields.Integer,
"trial_credits_used": fields.Integer,
"next_credit_reset_date": fields.Integer,
}
tenants_fields = {

View File

@ -358,12 +358,14 @@ def annotation_import_rate_limit(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
# Check per-minute rate limit
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
redis_client.zadd(minute_key, {current_time: current_time})
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
minute_count = redis_client.zcard(minute_key)
redis_client.expire(minute_key, 120) # 2 minutes TTL
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
abort(
429,
@ -377,6 +379,7 @@ def annotation_import_rate_limit(view: Callable[P, R]):
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
hour_count = redis_client.zcard(hour_key)
redis_client.expire(hour_key, 7200) # 2 hours TTL
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
abort(
429,

View File

@ -4,18 +4,18 @@ from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
import services
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import FileResponse
from fields.file_fields import build_file_model
from ..common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
from ..common.schema import register_schema_models
from ..console.wraps import setup_required
from ..files import files_ns
from ..inner_api.plugin.wraps import get_user
@ -35,8 +35,6 @@ files_ns.schema_model(
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
register_schema_models(files_ns, FileResponse)
@files_ns.route("/upload/for-plugin")
class PluginUploadFileApi(Resource):
@ -53,7 +51,7 @@ class PluginUploadFileApi(Resource):
415: "Unsupported file type",
}
)
@files_ns.response(HTTPStatus.CREATED, "File uploaded", files_ns.models[FileResponse.__name__])
@files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED)
def post(self):
"""Upload a file for plugin usage.
@ -71,7 +69,7 @@ class PluginUploadFileApi(Resource):
"""
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
file = request.files.get("file")
file: FileStorage | None = request.files.get("file")
if file is None:
raise Forbidden("File is required.")
@ -82,8 +80,8 @@ class PluginUploadFileApi(Resource):
user_id = args.user_id
user = get_user(tenant_id, user_id)
filename = file.filename
mimetype = file.mimetype
filename: str | None = file.filename
mimetype: str | None = file.mimetype
if not filename or not mimetype:
raise Forbidden("Invalid request.")
@ -113,22 +111,22 @@ class PluginUploadFileApi(Resource):
preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
# Create a dictionary with all the necessary attributes
result = FileResponse(
id=tool_file.id,
name=tool_file.name,
size=tool_file.size,
extension=extension,
mime_type=mimetype,
preview_url=preview_url,
source_url=tool_file.original_url,
original_url=tool_file.original_url,
user_id=tool_file.user_id,
tenant_id=tool_file.tenant_id,
conversation_id=tool_file.conversation_id,
file_key=tool_file.file_key,
)
result = {
"id": tool_file.id,
"user_id": tool_file.user_id,
"tenant_id": tool_file.tenant_id,
"conversation_id": tool_file.conversation_id,
"file_key": tool_file.file_key,
"mimetype": tool_file.mimetype,
"original_url": tool_file.original_url,
"name": tool_file.name,
"size": tool_file.size,
"mime_type": mimetype,
"extension": extension,
"preview_url": preview_url,
}
return result.model_dump(mode="json"), 201
return result, 201
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:

View File

@ -10,16 +10,13 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from fields.file_fields import FileResponse
from fields.file_fields import build_file_model
from models import App, EndUser
from services.file_service import FileService
register_schema_models(service_api_ns, FileResponse)
@service_api_ns.route("/files/upload")
class FileApi(Resource):
@ -34,8 +31,8 @@ class FileApi(Resource):
415: "Unsupported file type",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore
@service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__])
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App, end_user: EndUser):
"""Upload a file for use in conversations.
@ -67,5 +64,4 @@ class FileApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201
return upload_file, 201

View File

@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args(service_api_ns.payload)
args = self.parse_args()
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)

View File

@ -174,7 +174,7 @@ class PipelineRunApi(DatasetApiResource):
pipeline=pipeline,
user=current_user,
args=payload.model_dump(),
invoke_from=InvokeFrom.PUBLISHED_PIPELINE if payload.is_published else InvokeFrom.DEBUGGER,
invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
streaming=payload.response_mode == "streaming",
)

View File

@ -1,11 +1,9 @@
from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
@ -23,35 +21,6 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
from services.web_conversation_service import WebConversationService
class ConversationListQuery(BaseModel):
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
@web_ns.route("/conversations")
class ConversationListApi(WebApiResource):
@web_ns.doc("Get Conversation List")
@ -95,8 +64,25 @@ class ConversationListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
raw_args = request.args.to_dict()
query = ConversationListQuery.model_validate(raw_args)
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = args["pinned"] == "true"
try:
with Session(db.engine) as session:
@ -104,11 +90,11 @@ class ConversationListApi(WebApiResource):
session=session,
app_model=app_model,
user=end_user,
last_id=query.last_id,
limit=query.limit,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.WEB_APP,
pinned=query.pinned,
sort_by=query.sort_by,
pinned=pinned,
sort_by=args["sort_by"],
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
@ -182,11 +168,16 @@ class ConversationRenameApi(WebApiResource):
conversation_id = str(c_id)
payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args()
try:
conversation = ConversationService.rename(
app_model, conversation_id, end_user, payload.name, payload.auto_generate
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
)
return (
TypeAdapter(SimpleConversation)

View File

@ -1,4 +1,5 @@
from flask import request
from flask_restx import marshal_with
import services
from controllers.common.errors import (
@ -8,15 +9,12 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from fields.file_fields import FileResponse
from fields.file_fields import build_file_model
from services.file_service import FileService
register_schema_models(web_ns, FileResponse)
@web_ns.route("/files/upload")
class FileApi(WebApiResource):
@ -30,7 +28,7 @@ class FileApi(WebApiResource):
415: "Unsupported file type",
}
)
@web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__])
@marshal_with(build_file_model(web_ns))
def post(self, app_model, end_user):
"""Upload a file for use in web applications.
@ -83,5 +81,4 @@ class FileApi(WebApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201
return upload_file, 201

View File

@ -4,6 +4,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
@ -21,7 +22,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService
@ -69,9 +70,6 @@ class ForgotPasswordSendEmailApi(Resource):
def post(self):
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
request_email = payload.email
normalized_email = request_email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@ -82,12 +80,12 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
token = None
if account is None:
raise AuthenticationFailedError()
else:
token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
return {"result": "success", "data": token}
@ -106,9 +104,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
user_email = payload.email.lower()
user_email = payload.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@ -116,16 +114,11 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
if user_email != token_data.get("email"):
raise InvalidEmailError()
if payload.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(user_email)
AccountService.add_forgot_password_error_rate_limit(payload.email)
raise EmailCodeError()
# Verified, revoke the first token
@ -133,11 +126,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
token_email, code=payload.code, additional_data={"phase": "reset"}
user_email, code=payload.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(payload.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@web_ns.route("/forgot-password/resets")
@ -181,7 +174,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
self._update_existing_account(account, password_hashed, salt, session)

View File

@ -10,12 +10,7 @@ from controllers.console.auth.error import (
InvalidEmailError,
)
from controllers.console.error import AccountBannedError
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
only_edition_enterprise,
setup_required,
)
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import email
@ -47,7 +42,6 @@ class LoginApi(Resource):
404: "Account not found",
}
)
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
parser = (
@ -187,7 +181,6 @@ class EmailCodeLoginApi(Resource):
404: "Account not found",
}
)
@decrypt_code_field
def post(self):
parser = (
reqparse.RequestParser()
@ -197,29 +190,25 @@ class EmailCodeLoginApi(Resource):
)
args = parser.parse_args()
user_email = args["email"].lower()
user_email = args["email"]
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(token_email)
account = WebAppAuthService.get_user_through_email(user_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(user_email)
AccountService.reset_login_error_rate_limit(args["email"])
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response

View File

@ -1,6 +1,7 @@
import urllib.parse
import httpx
from flask_restx import marshal_with
from pydantic import BaseModel, Field, HttpUrl
import services
@ -13,7 +14,7 @@ from controllers.common.errors import (
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
from services.file_service import FileService
from ..common.schema import register_schema_models
@ -25,7 +26,7 @@ class RemoteFileUploadPayload(BaseModel):
url: HttpUrl = Field(description="Remote file URL")
register_schema_models(web_ns, RemoteFileUploadPayload, RemoteFileInfo, FileWithSignedUrl)
register_schema_models(web_ns, RemoteFileUploadPayload)
@web_ns.route("/remote-files/<path:url>")
@ -40,7 +41,7 @@ class RemoteFileInfoApi(WebApiResource):
500: "Failed to fetch remote file",
}
)
@web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__])
@marshal_with(build_remote_file_info_model(web_ns))
def get(self, app_model, end_user, url):
"""Get information about a remote file.
@ -64,11 +65,10 @@ class RemoteFileInfoApi(WebApiResource):
# failed back to get method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", -1)),
)
return info.model_dump(mode="json")
return {
"file_type": resp.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(resp.headers.get("Content-Length", -1)),
}
@web_ns.route("/remote-files/upload")
@ -84,7 +84,7 @@ class RemoteFileUploadApi(WebApiResource):
500: "Failed to fetch remote file",
}
)
@web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__])
@marshal_with(build_file_with_signed_url_model(web_ns))
def post(self, app_model, end_user):
"""Upload a file from a remote URL.
@ -139,14 +139,13 @@ class RemoteFileUploadApi(WebApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError
payload1 = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return payload1.model_dump(mode="json"), 201
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201

View File

@ -1,30 +1,18 @@
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from libs.helper import uuid_value
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
@web_ns.doc("Get Saved Messages")
@ -54,10 +42,14 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
raw_args = request.args.to_dict()
query = SavedMessageListQuery.model_validate(raw_args)
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
@ -87,10 +79,11 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args()
try:
SavedMessageService.save(app_model, end_user, payload.message_id)
SavedMessageService.save(app_model, end_user, args["message_id"])
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -1,380 +0,0 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
logger = logging.getLogger(__name__)
class AgentAppRunner(BaseAgentRunner):
def _create_tool_invoke_hook(self, message: Message):
"""
Create a tool invoke hook that uses ToolEngine.agent_invoke.
This hook handles file creation and returns proper meta information.
"""
# Get trace manager from app generate entity
trace_manager = self.application_generate_entity.trace_manager
def tool_invoke_hook(
tool: Tool, tool_args: dict[str, Any], tool_name: str
) -> tuple[str, list[str], ToolInvokeMeta]:
"""Hook that uses agent_invoke for proper file and meta handling."""
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters=tool_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=message.id,
conversation_id=self.conversation.id,
)
# Publish files and track IDs
for message_file_id in message_files:
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id),
PublishFrom.APPLICATION_MANAGER,
)
self._current_message_file_ids.append(message_file_id)
return tool_invoke_response, message_files, tool_invoke_meta
return tool_invoke_hook
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run Agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, _ = self._init_prompt_tools()
assert app_config.agent
# Create tool invoke hook for agent_invoke
tool_invoke_hook = self._create_tool_invoke_hook(message)
# Get instruction for ReAct strategy
instruction = self.app_config.prompt_template.simple_prompt_template or ""
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
model_features=self.model_features,
model_instance=self.model_instance,
tools=list(tool_instances.values()),
files=list(self.files),
max_iterations=app_config.agent.max_iteration,
context=self.build_execution_context(),
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Initialize state variables
current_agent_thought_id = None
has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids: list[str] = []
# organize prompt messages
prompt_messages = self._organize_prompt_messages()
# Run strategy
generator = strategy.run(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=True,
)
# Consume generator and collect result
result: AgentResult | None = None
try:
while True:
try:
output = next(generator)
except StopIteration as e:
# Generator finished, get the return value
result = e.value
break
if isinstance(output, LLMResultChunk):
# Handle LLM chunk
if current_agent_thought_id and not has_published_thought:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
has_published_thought = True
yield output
elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
# Start of a new round
message_file_ids: list[str] = []
current_agent_thought_id = self.create_agent_thought(
message_id=message.id,
message="",
tool_name="",
tool_input="",
messages_ids=message_file_ids,
)
has_published_thought = False
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call start - extract data from structured fields
current_tool_name = output.data.get("tool_name", "")
tool_input = output.data.get("tool_args", {})
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=current_tool_name,
tool_input=tool_input,
thought=None,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.status == AgentLog.LogStatus.SUCCESS:
if output.log_type == AgentLog.LogType.THOUGHT:
if current_agent_thought_id is None:
continue
thought_text = output.data.get("thought")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=thought_text,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call finished
tool_output = output.data.get("output")
# Get meta from strategy output (now properly populated)
tool_meta = output.data.get("meta")
# Wrap tool_meta with tool_name as key (required by agent_service)
if tool_meta and current_tool_name:
tool_meta = {current_tool_name: tool_meta}
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=None,
observation=tool_output,
tool_invoke_meta=tool_meta,
answer=None,
messages_ids=self._current_message_file_ids,
)
# Clear message file ids after saving
self._current_message_file_ids = []
current_tool_name = None
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.ROUND:
if current_agent_thought_id is None:
continue
# Round finished - save LLM usage and answer
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
llm_result = output.data.get("llm_result")
final_answer = output.data.get("final_answer")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=llm_result,
observation=None,
tool_invoke_meta=None,
answer=final_answer,
messages_ids=[],
llm_usage=llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
except Exception:
# Re-raise any other exceptions
raise
# Process final result
if isinstance(result, AgentResult):
final_answer = result.text
usage = result.usage or LLMUsage.empty_usage()
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""
if not prompt_template:
return prompt_messages or []
prompt_messages = prompt_messages or []
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
return prompt_messages
if not prompt_messages:
return [SystemPromptMessage(content=prompt_template)]
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
We need to remove the image messages from the prompt messages at the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
# For ReAct strategy, use the agent prompt template
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
prompt_template = self.config.prompt.first_prompt
else:
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages

View File

@ -1,12 +1,11 @@
import json
import logging
import uuid
from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -42,7 +41,6 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@ -116,20 +114,9 @@ class BaseAgentRunner(AppRunner):
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.model_features = features
self.query: str | None = ""
self._current_thoughts: list[PromptMessage] = []
def build_execution_context(self) -> ExecutionContext:
"""Build execution context."""
return ExecutionContext(
user_id=self.user_id,
app_id=self.app_config.app_id,
conversation_id=self.conversation.id,
message_id=self.message.id,
tenant_id=self.tenant_id,
)
def _repack_app_generate_entity(
self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:
@ -302,7 +289,6 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@ -310,20 +296,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
message_unit_price=0,
message_price_unit=0,
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=Decimal(0),
answer_price_unit=Decimal("0.001"),
answer_unit_price=0,
answer_price_unit=0,
tokens=0,
total_price=Decimal(0),
total_price=0,
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
created_by_role=CreatorUserRole.ACCOUNT,
created_by_role="account",
created_by=self.user_id,
)
@ -356,8 +342,7 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
existing_thought = agent_thought.thought or ""
agent_thought.thought = f"{existing_thought}{thought}"
agent_thought.thought += thought
if tool_name:
agent_thought.tool = tool_name
@ -455,30 +440,21 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tool_names_raw = agent_thought.tool
if tool_names_raw:
tool_names = tool_names_raw.split(";")
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
tool_input_payload = agent_thought.tool_input
if tool_input_payload:
try:
tool_inputs = json.loads(tool_input_payload)
except Exception:
tool_inputs = {tool: {} for tool in tool_names}
else:
tool_inputs = {tool: {} for tool in tool_names}
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)
observation_payload = agent_thought.observation
if observation_payload:
try:
tool_responses = json.loads(observation_payload)
except Exception:
tool_responses = dict.fromkeys(tool_names, observation_payload)
else:
tool_responses = dict.fromkeys(tool_names, observation_payload)
for tool in tool_names:
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@ -508,7 +484,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
if not tool_names_raw:
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:

View File

@ -0,0 +1,437 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str
_query: str
_prompt_messages_tools: Sequence[PromptMessageTool]
def run(
self,
message: Message,
query: str,
inputs: Mapping[str, str],
) -> Generator:
"""
Run Cot agent application
"""
app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query)
trace_manager = app_generate_entity.trace_manager
# check model mode
if "Observation" not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config
assert app_config.agent
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template or ""
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
function_call_state = True
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
agent_thought_id = "" # Initialize agent_thought_id
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there is not any tool call
function_call_state = False
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
self._prompt_messages_tools = []
message_file_ids: list[str] = []
agent_thought_id = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
if iteration_step > 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=[],
stop=app_generate_entity.model_conf.stop,
stream=True,
user=self.user_id,
callbacks=[],
)
usage_dict: dict[str, LLMUsage | None] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk
assert scratchpad.thought is not None
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
assert scratchpad.thought is not None
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and scratchpad.action:
if scratchpad.action.action_name.lower() != "final answer":
raise AgentMaxIterationError(app_config.agent.max_iteration)
# get llm usage
if "usage" in usage_dict:
if usage_dict["usage"] is not None:
increase_usage(llm_usage, usage_dict["usage"])
else:
usage_dict["usage"] = LLMUsage.empty_usage()
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought or "",
observation="",
answer=scratchpad.agent_response or "",
messages_ids=[],
llm_usage=usage_dict["usage"],
)
if not scratchpad.is_final():
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = ""
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
final_answer = f"{scratchpad.action.action_input}"
except TypeError:
final_answer = f"{scratchpad.action.action_input}"
else:
function_call_state = True
# action is tool call, invoke tool
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
action=scratchpad.action,
tool_instances=tool_instances,
message_file_ids=message_file_ids,
trace_manager=trace_manager,
)
scratchpad.observation = tool_invoke_response
scratchpad.agent_response = tool_invoke_response
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought or "",
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict["usage"],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
),
system_fingerprint="",
)
# save agent thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name="",
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation={},
answer=final_answer,
messages_ids=[],
)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: Mapping[str, Tool],
message_file_ids: list[str],
trace_manager: TraceQueueManager | None = None,
) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
:param action: action
:param tool_instances: tool instances
:param message_file_ids: message file ids
:param trace_manager: trace manager
:return: observation, meta
"""
# action is tool call, invoke tool
tool_call_name = action.action_name
tool_call_args = action.action_input
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
answer = f"there is not a tool named {tool_call_name}"
return answer, ToolInvokeMeta.error_instance(answer)
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
)
# publish files
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
return tool_invoke_response, tool_invoke_meta
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
"""
convert dict to action
"""
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
"""
fill in inputs from external data tools
"""
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception:
continue
return instruction
def _init_react_state(self, query):
"""
init agent scratchpad
"""
self._query = query
self._agent_scratchpad = []
self._historic_prompt_messages = self._organize_historic_prompt_messages()
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
organize prompt messages
"""
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
"""
message = ""
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
else:
message += f"Thought: {scratchpad.thought}\n\n"
if scratchpad.action_str:
message += f"Action: {scratchpad.action_str}\n\n"
if scratchpad.observation:
message += f"Observation: {scratchpad.observation}\n\n"
return message
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] | None = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory,
).get_prompt()
return historic_prompts

View File

@ -0,0 +1,118 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.utils.encoders import jsonable_encoder
class CotChatAgentRunner(CotAgentRunner):
def _organize_system_prompt(self) -> SystemPromptMessage:
"""
Organize system prompt
"""
assert self.app_config.agent
assert self.app_config.agent.prompt
prompt_entity = self.app_config.agent.prompt
if not prompt_entity:
raise ValueError("Agent prompt configuration is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize
"""
# organize system prompt
system_message = self._organize_system_prompt()
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content="")
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_message.content += f"Observation: {unit.observation}\n\n"
assistant_messages = [assistant_message]
# query messages
query_messages = self._organize_user_query(self._query, [])
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages(
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
)
messages = [
system_message,
*historic_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content="continue"),
]
else:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
messages = [system_message, *historic_messages, *query_messages]
# join all messages
return messages

View File

@ -0,0 +1,87 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
class CotCompletionAgentRunner(CotAgentRunner):
def _organize_instruction_prompt(self) -> str:
"""
Organize instruction prompt
"""
if self.app_config.agent is None:
raise ValueError("Agent configuration is not set")
prompt_entity = self.app_config.agent.prompt
if prompt_entity is None:
raise ValueError("prompt entity is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return system_prompt
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
"""
Organize historic prompt
"""
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
historic_prompt = ""
for message in historic_prompt_messages:
if isinstance(message, UserPromptMessage):
historic_prompt += f"Question: {message.content}\n\n"
elif isinstance(message, AssistantPromptMessage):
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
return historic_prompt
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages
"""
# organize system prompt
system_prompt = self._organize_instruction_prompt()
# organize historic prompt messages
historic_prompt = self._organize_historic_prompt()
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ""
for unit in agent_scratchpad or []:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
else:
assistant_prompt += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_prompt += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_prompt += f"Observation: {unit.observation}\n\n"
# query messages
query_prompt = f"Question: {self._query}"
# join all messages
prompt = (
system_prompt.replace("{{historic_messages}}", historic_prompt)
.replace("{{agent_scratchpad}}", assistant_prompt)
.replace("{{query}}", query_prompt)
)
return [UserPromptMessage(content=prompt)]

View File

@ -1,5 +1,3 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Union
@ -94,96 +92,3 @@ class AgentInvokeMessage(ToolInvokeMessage):
"""
pass
class ExecutionContext(BaseModel):
"""Execution context containing trace and audit information.
This context carries all the IDs and metadata that are not part of
the core business logic but needed for tracing, auditing, and
correlation purposes.
"""
user_id: str | None = None
app_id: str | None = None
conversation_id: str | None = None
message_id: str | None = None
tenant_id: str | None = None
@classmethod
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
"""Create a minimal context with only essential fields."""
return cls(user_id=user_id)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for passing to legacy code."""
return {
"user_id": self.user_id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"message_id": self.message_id,
"tenant_id": self.tenant_id,
}
def with_updates(self, **kwargs) -> "ExecutionContext":
"""Create a new context with updated fields."""
data = self.to_dict()
data.update(kwargs)
return ExecutionContext(
user_id=data.get("user_id"),
app_id=data.get("app_id"),
conversation_id=data.get("conversation_id"),
message_id=data.get("message_id"),
tenant_id=data.get("tenant_id"),
)
class AgentLog(BaseModel):
"""
Agent Log.
"""
class LogType(StrEnum):
"""Type of agent log entry."""
ROUND = "round" # A complete iteration round
THOUGHT = "thought" # LLM thinking/reasoning
TOOL_CALL = "tool_call" # Tool invocation
class LogMetadata(StrEnum):
STARTED_AT = "started_at"
FINISHED_AT = "finished_at"
ELAPSED_TIME = "elapsed_time"
TOTAL_PRICE = "total_price"
TOTAL_TOKENS = "total_tokens"
PROVIDER = "provider"
CURRENCY = "currency"
LLM_USAGE = "llm_usage"
ICON = "icon"
ICON_DARK = "icon_dark"
class LogStatus(StrEnum):
START = "start"
ERROR = "error"
SUCCESS = "success"
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log")
label: str = Field(..., description="The label of the log")
log_type: LogType = Field(..., description="The type of the log")
parent_id: str | None = Field(default=None, description="Leave empty for root log")
error: str | None = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
class AgentResult(BaseModel):
"""
Agent execution result.
"""
text: str = Field(default="", description="The generated text")
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
usage: Any | None = Field(default=None, description="LLM usage statistics")
finish_reason: str | None = Field(default=None, description="Reason for completion")

View File

@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
),
)
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
@ -200,6 +200,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)

View File

@ -1,55 +0,0 @@
# Agent Patterns
A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.
## Overview
The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.
## Key Features
- **Dual strategies**
- `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
- `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
- **Explicit or auto selection**
- `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
- Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
- **Unified execution contract**
- `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
- Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
- **Tool handling and hooks**
- Tools convert to `PromptMessageTool` objects before invocation.
- Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
- Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs.
- **File-aware arguments**
- Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely.
- **ReAct prompt shaping**
- System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
- Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
- **Observability and accounting**
- Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.
## Architecture
```
agent/patterns/
├── base.py # Shared utilities: logging, usage, tool invocation, file handling
├── function_call.py # Native function-calling loop with tool execution
├── react.py # ReAct loop with CoT parsing and scratchpad wiring
└── strategy_factory.py # Strategy selection by model features or explicit override
```
## Usage
- For auto-selection:
- Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params.
- For explicit behavior:
- Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`.
## Integration Points
- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers.
- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments.
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.

View File

@ -1,19 +0,0 @@
"""Agent patterns module.
This module provides different strategies for agent execution:
- FunctionCallStrategy: Uses native function/tool calling
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
- StrategyFactory: Factory for creating strategies based on model features
"""
from .base import AgentPattern
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
from .strategy_factory import StrategyFactory
__all__ = [
"AgentPattern",
"FunctionCallStrategy",
"ReActStrategy",
"StrategyFactory",
]

View File

@ -1,474 +0,0 @@
"""Base class for agent strategies."""
from __future__ import annotations
import json
import re
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
# Type alias for tool invoke hook
# Returns: (response_content, message_file_ids, tool_invoke_meta)
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
class AgentPattern(ABC):
"""Base class for agent execution strategies."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] = [],
tool_invoke_hook: ToolInvokeHook | None = None,
):
"""Initialize the agent strategy."""
self.model_instance = model_instance
self.tools = tools
self.context = context
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
self.workflow_call_depth = workflow_call_depth
self.files: list[File] = files
self.tool_invoke_hook = tool_invoke_hook
@abstractmethod
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the agent strategy."""
pass
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
"""Accumulate LLM usage statistics."""
if not total_usage.get("usage"):
# Create a copy to avoid modifying the original
total_usage["usage"] = LLMUsage(
prompt_tokens=delta_usage.prompt_tokens,
prompt_unit_price=delta_usage.prompt_unit_price,
prompt_price_unit=delta_usage.prompt_price_unit,
prompt_price=delta_usage.prompt_price,
completion_tokens=delta_usage.completion_tokens,
completion_unit_price=delta_usage.completion_unit_price,
completion_price_unit=delta_usage.completion_price_unit,
completion_price=delta_usage.completion_price,
total_tokens=delta_usage.total_tokens,
total_price=delta_usage.total_price,
currency=delta_usage.currency,
latency=delta_usage.latency,
)
else:
current: LLMUsage = total_usage["usage"]
current.prompt_tokens += delta_usage.prompt_tokens
current.completion_tokens += delta_usage.completion_tokens
current.total_tokens += delta_usage.total_tokens
current.prompt_price += delta_usage.prompt_price
current.completion_price += delta_usage.completion_price
current.total_price += delta_usage.total_price
def _extract_content(self, content: Any) -> str:
"""Extract text content from message content."""
if isinstance(content, list):
# Content items are PromptMessageContentUnionTypes
text_parts = []
for c in content:
# Check if it's a TextPromptMessageContent (which has data attribute)
if isinstance(c, TextPromptMessageContent):
text_parts.append(c.data)
return "".join(text_parts)
return str(content)
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
"""Check if chunk contains tool calls."""
# LLMResultChunk always has delta attribute
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
def _has_tool_calls_result(self, result: LLMResult) -> bool:
"""Check if result contains tool calls (non-streaming)."""
# LLMResult always has message attribute
return bool(result.message and result.message.tool_calls)
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from streaming chunk."""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
if chunk.delta.message and chunk.delta.message.tool_calls:
for tool_call in chunk.delta.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from non-streaming result."""
tool_calls = []
if result.message and result.message.tool_calls:
for tool_call in result.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_text_from_message(self, message: PromptMessage) -> str:
"""Extract text content from a prompt message."""
# PromptMessage always has content attribute
content = message.content
if isinstance(content, str):
return content
elif isinstance(content, list):
# Extract text from content list
text_parts = []
for item in content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return " ".join(text_parts)
return ""
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
"""Get metadata for a tool including provider and icon info."""
from core.tools.tool_manager import ToolManager
metadata: dict[AgentLog.LogMetadata, Any] = {}
if tool_instance.entity and tool_instance.entity.identity:
identity = tool_instance.entity.identity
if identity.provider:
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
# Get icon using ToolManager for proper URL generation
tenant_id = self.context.tenant_id
if tenant_id and identity.provider:
try:
provider_type = tool_instance.tool_provider_type()
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
if isinstance(icon, str):
metadata[AgentLog.LogMetadata.ICON] = icon
elif isinstance(icon, dict):
# Handle icon dict with background/content or light/dark variants
metadata[AgentLog.LogMetadata.ICON] = icon
except Exception:
# Fallback to identity.icon if ToolManager fails
if identity.icon:
metadata[AgentLog.LogMetadata.ICON] = identity.icon
elif identity.icon:
metadata[AgentLog.LogMetadata.ICON] = identity.icon
return metadata
def _create_log(
self,
label: str,
log_type: AgentLog.LogType,
status: AgentLog.LogStatus,
data: dict[str, Any] | None = None,
parent_id: str | None = None,
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
) -> AgentLog:
"""Create a new AgentLog with standard metadata."""
metadata: dict[AgentLog.LogMetadata, Any] = {
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
}
if extra_metadata:
metadata.update(extra_metadata)
return AgentLog(
label=label,
log_type=log_type,
status=status,
data=data or {},
parent_id=parent_id,
metadata=metadata,
)
def _finish_log(
self,
log: AgentLog,
data: dict[str, Any] | None = None,
usage: LLMUsage | None = None,
) -> AgentLog:
"""Finish an AgentLog by updating its status and metadata."""
log.status = AgentLog.LogStatus.SUCCESS
if data is not None:
log.data = data
# Calculate elapsed time
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
finished_at = time.perf_counter()
# Update metadata
log.metadata = {
**log.metadata,
AgentLog.LogMetadata.FINISHED_AT: finished_at,
# Calculate elapsed time in seconds
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
}
# Add usage information if provided
if usage:
log.metadata.update(
{
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
AgentLog.LogMetadata.CURRENCY: usage.currency,
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
AgentLog.LogMetadata.LLM_USAGE: usage,
}
)
return log
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
"""
Replace file references in tool arguments with actual File objects.
Args:
tool_args: Dictionary of tool arguments
Returns:
Updated tool arguments with file references replaced
"""
# Process each argument in the dictionary
processed_args: dict[str, Any] = {}
for key, value in tool_args.items():
processed_args[key] = self._process_file_reference(value)
return processed_args
def _process_file_reference(self, data: Any) -> Any:
"""
Recursively process data to replace file references.
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
Args:
data: The data to process (can be dict, list, str, or other types)
Returns:
Processed data with file references replaced
"""
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
if isinstance(data, dict):
# Process dictionary recursively
return {key: self._process_file_reference(value) for key, value in data.items()}
elif isinstance(data, list):
# Process list recursively
return [self._process_file_reference(item) for item in data]
elif isinstance(data, str):
# Check for single file pattern [File: file_id]
single_match = single_file_pattern.match(data.strip())
if single_match:
file_id = single_match.group(1).strip()
# Find the file in self.files
for file in self.files:
if file.id and str(file.id) == file_id:
return file
# If file not found, return original value
return data
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
multiple_match = multiple_files_pattern.match(data.strip())
if multiple_match:
file_ids_str = multiple_match.group(1).strip()
# Split by comma and strip whitespace
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
# Find all matching files
matched_files: list[File] = []
for file_id in file_ids:
for file in self.files:
if file.id and str(file.id) == file_id:
matched_files.append(file)
break
# Return list of files if any were found, otherwise return original
return matched_files or data
return data
else:
# Return other types as-is
return data
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
"""Create a text chunk for streaming."""
return LLMResultChunk(
model=self.model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
usage=None,
),
system_fingerprint="",
)
def _invoke_tool(
self,
tool_instance: Tool,
tool_args: dict[str, Any],
tool_name: str,
) -> tuple[str, list[File], ToolInvokeMeta | None]:
"""
Invoke a tool and collect its response.
Args:
tool_instance: The tool instance to invoke
tool_args: Tool arguments
tool_name: Name of the tool
Returns:
Tuple of (response_content, tool_files, tool_invoke_meta)
"""
# Process tool_args to replace file references with actual File objects
tool_args = self._replace_file_references(tool_args)
# If a tool invoke hook is set, use it instead of generic_invoke
if self.tool_invoke_hook:
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
# The caller (AgentAppRunner) handles file publishing
return response_content, [], tool_invoke_meta
# Default: use generic_invoke for workflow scenarios
# Import here to avoid circular import
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
tool_response = ToolEngine().generic_invoke(
tool=tool_instance,
tool_parameters=tool_args,
user_id=self.context.user_id or "",
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=self.context.app_id,
conversation_id=self.context.conversation_id,
message_id=self.context.message_id,
)
# Collect response and files
response_content = ""
tool_files: list[File] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
response_content += response.message.text
elif response.type == ToolInvokeMessage.MessageType.LINK:
# Handle link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Link: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
# Handle image URL messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
# Handle image link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
# Handle binary file link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
filename = response.meta.get("filename", "file") if response.meta else "file"
response_content += f"[File: {filename} - {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.JSON:
# Handle JSON messages
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# Handle blob messages - convert to text representation
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
mime_type = (
response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream"
)
size = len(response.message.blob)
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
# Handle variable messages
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
var_name = response.message.variable_name
var_value = response.message.variable_value
if isinstance(var_value, str):
response_content += var_value
else:
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
# Handle blob chunk messages - these are parts of a larger blob
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
# Handle retriever resources messages
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
response_content += response.message.context
elif response.type == ToolInvokeMessage.MessageType.FILE:
# Extract file from meta
if response.meta and "file" in response.meta:
file = response.meta["file"]
if isinstance(file, File):
# Check if file is for model or tool output
if response.meta.get("target") == "self":
# File is for model - add to files for next prompt
self.files.append(file)
response_content += f"File '{file.filename}' has been loaded into your context."
else:
# File is tool output
tool_files.append(file)
return response_content, tool_files, None
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
"""Find a tool instance by its name."""
for tool in self.tools:
if tool.entity.identity.name == tool_name:
return tool
return None
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
"""Convert tools to prompt message format."""
prompt_tools: list[PromptMessageTool] = []
for tool in self.tools:
prompt_tools.append(tool.to_prompt_message_tool())
return prompt_tools
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
"""Initialize usage tracking with empty usage if not set."""
if "usage" not in llm_usage or llm_usage["usage"] is None:
llm_usage["usage"] = LLMUsage.empty_usage()

View File

@ -1,299 +0,0 @@
"""Function Call strategy implementation."""
import json
from collections.abc import Generator
from typing import Any, Union
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from .base import AgentPattern
class FunctionCallStrategy(AgentPattern):
"""Function Call strategy using model's native tool calling capability."""
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
# Initialize tracking
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
function_call_state: bool = True
total_usage: dict[str, LLMUsage | None] = {"usage": None}
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
while function_call_state and iteration_step <= max_iterations:
function_call_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
yield round_log
# On last iteration, remove tools to force final answer
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, LLMUsage | None] = {"usage": None}
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=current_tools,
stop=stop,
stream=stream,
user=self.context.user_id,
callbacks=[],
)
# Process response
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log
)
messages.append(self._create_assistant_message(response_content, tool_calls))
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update final text if no tool calls (this is likely the final answer)
if not tool_calls:
final_text = response_content
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Process tool calls
tool_outputs: dict[str, str] = {}
if tool_calls:
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
yield self._finish_log(
round_log,
data={
"llm_result": response_content,
"tool_calls": [
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
]
if tool_calls
else [],
"final_answer": final_text if not function_call_state else None,
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text,
files=output_files,
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
finish_reason=finish_reason,
)
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, LLMUsage | None],
start_log: AgentLog,
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
]:
"""Handle LLM response chunks and extract tool calls and content.
Returns a tuple of (tool_calls, response_content, finish_reason).
"""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
response_content: str = ""
finish_reason: str | None = None
if isinstance(chunks, Generator):
# Streaming response
for chunk in chunks:
# Extract tool calls
if self._has_tool_calls(chunk):
tool_calls.extend(self._extract_tool_calls(chunk))
# Extract content
if chunk.delta.message and chunk.delta.message.content:
response_content += self._extract_content(chunk.delta.message.content)
# Track usage
if chunk.delta.usage:
self._accumulate_usage(llm_usage, chunk.delta.usage)
# Capture finish reason
if chunk.delta.finish_reason:
finish_reason = chunk.delta.finish_reason
yield chunk
else:
# Non-streaming response
result: LLMResult = chunks
if self._has_tool_calls_result(result):
tool_calls.extend(self._extract_tool_calls_result(result))
if result.message and result.message.content:
response_content += self._extract_content(result.message.content)
if result.usage:
self._accumulate_usage(llm_usage, result.usage)
# Convert to streaming format
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield self._finish_log(
start_log,
data={
"result": response_content,
},
usage=llm_usage.get("usage"),
)
return tool_calls, response_content, finish_reason
def _create_assistant_message(
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
) -> AssistantPromptMessage:
"""Create assistant message with tool calls."""
if tool_calls is None:
return AssistantPromptMessage(content=content)
return AssistantPromptMessage(
content=content or "",
tool_calls=[
AssistantPromptMessage.ToolCall(
id=tc[0],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
)
for tc in tool_calls
],
)
def _handle_tool_call(
self,
tool_name: str,
tool_args: dict[str, Any],
tool_call_id: str,
messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
"""Handle a single tool call and return response with files and meta."""
# Find tool
tool_instance = self._find_tool_by_name(tool_name)
if not tool_instance:
raise ValueError(f"Tool {tool_name} not found")
# Get tool metadata (provider, icon, etc.)
tool_metadata = self._get_tool_metadata(tool_instance)
# Create tool call log
tool_call_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
extra_metadata=tool_metadata,
)
yield tool_call_log
# Invoke tool using base class method with error handling
try:
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
yield self._finish_log(
tool_call_log,
data={
**tool_call_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
final_content = response_content or "Tool executed successfully"
# Add tool response to messages
messages.append(
ToolPromptMessage(
content=final_content,
tool_call_id=tool_call_id,
name=tool_name,
)
)
return response_content, tool_files, tool_invoke_meta
except Exception as e:
# Tool invocation failed, yield error log
error_message = str(e)
tool_call_log.status = AgentLog.LogStatus.ERROR
tool_call_log.error = error_message
tool_call_log.data = {
**tool_call_log.data,
"error": error_message,
}
yield tool_call_log
# Add error message to conversation
error_content = f"Tool execution failed: {error_message}"
messages.append(
ToolPromptMessage(
content=error_content,
tool_call_id=tool_call_id,
name=tool_name,
)
)
return error_content, [], None

View File

@ -1,418 +0,0 @@
"""ReAct strategy implementation."""
from __future__ import annotations
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Union
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
PromptMessage,
SystemPromptMessage,
)
from .base import AgentPattern, ToolInvokeHook
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class ReActStrategy(AgentPattern):
"""ReAct strategy using reasoning and acting approach."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] = [],
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
):
"""Initialize the ReAct strategy with instruction support."""
super().__init__(
model_instance=model_instance,
tools=tools,
context=context,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
files=files,
tool_invoke_hook=tool_invoke_hook,
)
self.instruction = instruction
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the ReAct agent strategy."""
# Initialize tracking
agent_scratchpad: list[AgentScratchpadUnit] = []
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
react_state: bool = True
total_usage: dict[str, Any] = {"usage": None}
output_files: list[File] = [] # Track files produced by tools
final_text: str = ""
finish_reason: str | None = None
# Add "Observation" to stop sequences
if "Observation" not in stop:
stop = stop.copy()
stop.append("Observation")
while react_state and iteration_step <= max_iterations:
react_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
yield round_log
# Build prompt with/without tools based on iteration
include_tools = iteration_step < max_iterations
current_messages = self._build_prompt_with_react_format(
prompt_messages, agent_scratchpad, include_tools, self.instruction
)
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, Any] = {"usage": None}
# Use current messages directly (files are handled by base class if needed)
messages_to_use = current_messages
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=self.context.user_id or "",
callbacks=[],
)
# Process response
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log, current_messages
)
agent_scratchpad.append(scratchpad)
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Check if we have an action to execute
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
react_state = True
# Execute tool
observation, tool_files = yield from self._handle_tool_call(
scratchpad.action, current_messages, round_log
)
scratchpad.observation = observation
# Track files produced by tools
output_files.extend(tool_files)
# Add observation to scratchpad for display
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
else:
# Extract final answer
if scratchpad.action and scratchpad.action.action_input:
final_answer = scratchpad.action.action_input
if isinstance(final_answer, dict):
final_answer = json.dumps(final_answer, ensure_ascii=False)
final_text = str(final_answer)
elif scratchpad.thought:
# If no action but we have thought, use thought as final answer
final_text = scratchpad.thought
yield self._finish_log(
round_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
"observation": scratchpad.observation or None,
"final_answer": final_text if not react_state else None,
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
)
def _build_prompt_with_react_format(
self,
original_messages: list[PromptMessage],
agent_scratchpad: list[AgentScratchpadUnit],
include_tools: bool = True,
instruction: str = "",
) -> list[PromptMessage]:
"""Build prompt messages with ReAct format."""
# Copy messages to avoid modifying original
messages = list(original_messages)
# Find and update the system prompt that should already exist
system_prompt_found = False
for i, msg in enumerate(messages):
if isinstance(msg, SystemPromptMessage):
system_prompt_found = True
# The system prompt from frontend already has the template, just replace placeholders
# Format tools
tools_str = ""
tool_names = []
if include_tools and self.tools:
# Convert tools to prompt message tools format
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
tool_names = [tool.name for tool in prompt_tools]
# Format tools as JSON for comprehensive information
from core.model_runtime.utils.encoders import jsonable_encoder
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
else:
tools_str = "No tools available"
tool_names_str = ""
# Replace placeholders in the existing system prompt
updated_content = msg.content
assert isinstance(updated_content, str)
updated_content = updated_content.replace("{{instruction}}", instruction)
updated_content = updated_content.replace("{{tools}}", tools_str)
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
# Create new SystemPromptMessage with updated content
messages[i] = SystemPromptMessage(content=updated_content)
break
# If no system prompt found, that's unexpected but add scratchpad anyway
if not system_prompt_found:
# This shouldn't happen if frontend is working correctly
pass
# Format agent scratchpad
scratchpad_str = ""
if agent_scratchpad:
scratchpad_parts: list[str] = []
for unit in agent_scratchpad:
if unit.thought:
scratchpad_parts.append(f"Thought: {unit.thought}")
if unit.action_str:
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
if unit.observation:
scratchpad_parts.append(f"Observation: {unit.observation}")
scratchpad_str = "\n".join(scratchpad_parts)
# If there's a scratchpad, append it to the last message
if scratchpad_str:
messages.append(AssistantPromptMessage(content=scratchpad_str))
return messages
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, Any],
model_log: AgentLog,
current_messages: list[PromptMessage],
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[AgentScratchpadUnit, str | None],
]:
"""Handle LLM response chunks and extract action/thought.
Returns a tuple of (scratchpad_unit, finish_reason).
"""
usage_dict: dict[str, Any] = {}
# Convert non-streaming to streaming format if needed
if isinstance(chunks, LLMResult):
# Create a generator from the LLMResult
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)
streaming_chunks = result_to_chunks()
else:
streaming_chunks = chunks
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
# Initialize scratchpad unit
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
finish_reason: str | None = None
# Process chunks
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
# Action detected
action_str = json.dumps(chunk.model_dump())
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
scratchpad.action_str = action_str
scratchpad.action = chunk
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
else:
# Text chunk
chunk_text = str(chunk)
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
scratchpad.thought = (scratchpad.thought or "") + chunk_text
yield self._create_text_chunk(chunk_text, current_messages)
# Update usage
if usage_dict.get("usage"):
if llm_usage.get("usage"):
self._accumulate_usage(llm_usage, usage_dict["usage"])
else:
llm_usage["usage"] = usage_dict["usage"]
# Clean up thought
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
# Finish model log
yield self._finish_log(
model_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
},
usage=llm_usage.get("usage"),
)
return scratchpad, finish_reason
def _handle_tool_call(
self,
action: AgentScratchpadUnit.Action,
prompt_messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
"""Handle tool call and return observation with files."""
tool_name = action.action_name
tool_args: dict[str, Any] | str = action.action_input
# Find tool instance first to get metadata
tool_instance = self._find_tool_by_name(tool_name)
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
# Start tool log with tool metadata
tool_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
extra_metadata=tool_metadata,
)
yield tool_log
if not tool_instance:
# Finish tool log with error
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"error": f"Tool {tool_name} not found",
},
)
return f"Tool {tool_name} not found", []
# Ensure tool_args is a dict
tool_args_dict: dict[str, Any]
if isinstance(tool_args, str):
try:
tool_args_dict = json.loads(tool_args)
except json.JSONDecodeError:
tool_args_dict = {"input": tool_args}
elif not isinstance(tool_args, dict):
tool_args_dict = {"input": str(tool_args)}
else:
tool_args_dict = tool_args
# Invoke tool using base class method with error handling
try:
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
# Finish tool log
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
return response_content or "Tool executed successfully", tool_files
except Exception as e:
# Tool invocation failed, yield error log
error_message = str(e)
tool_log.status = AgentLog.LogStatus.ERROR
tool_log.error = error_message
tool_log.data = {
**tool_log.data,
"error": error_message,
}
yield tool_log
return f"Tool execution failed: {error_message}", []

View File

@ -1,107 +0,0 @@
"""Strategy factory for creating agent strategies."""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.agent.entities import AgentEntity, ExecutionContext
from core.file.models import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class StrategyFactory:
"""Factory for creating agent strategies based on model features."""
# Tool calling related features
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
@staticmethod
def create_strategy(
model_features: list[ModelFeature],
model_instance: ModelInstance,
context: ExecutionContext,
tools: list[Tool],
files: list[File],
max_iterations: int = 10,
workflow_call_depth: int = 0,
agent_strategy: AgentEntity.Strategy | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
) -> AgentPattern:
"""
Create an appropriate strategy based on model features.
Args:
model_features: List of model features/capabilities
model_instance: Model instance to use
context: Execution context containing trace/audit information
tools: Available tools
files: Available files
max_iterations: Maximum iterations for the strategy
workflow_call_depth: Depth of workflow calls
agent_strategy: Optional explicit strategy override
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
instruction: Optional instruction for ReAct strategy
Returns:
AgentStrategy instance
"""
# If explicit strategy is provided and it's Function Calling, try to use it if supported
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
# Fallback to ReAct if FC is requested but not supported
# If explicit strategy is Chain of Thought (ReAct)
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Default auto-selection logic
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
# Model supports native function calling
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
else:
# Use ReAct strategy for models without function calling
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict | None = Field(default=None)
json_schema: str | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict | None) -> dict | None:
def validate_json_schema(cls, schema: str | None) -> str | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema

View File

@ -26,6 +26,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,

View File

@ -20,8 +20,6 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
@ -39,9 +37,9 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableUpdater
logger = logging.getLogger(__name__)
@ -105,11 +103,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
invoke_from = self.application_generate_entity.invoke_from
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
@ -162,8 +155,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
)
db.session.close()
@ -181,8 +172,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
@ -205,10 +200,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
conversation_variable_layer = ConversationVariablePersistenceLayer(
ConversationVariableUpdater(session_factory.get_session_maker())
)
workflow_entry.graph_engine.layer(conversation_variable_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)

View File

@ -4,7 +4,6 @@ import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Union
@ -20,7 +19,6 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.queue_entities import (
ChunkType,
MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
@ -72,122 +70,13 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@dataclass
class StreamEventBuffer:
"""
Buffer for recording stream events in order to reconstruct the generation sequence.
Records the exact order of text chunks, thoughts, and tool calls as they stream.
"""
# Accumulated reasoning content (each thought block is a separate element)
reasoning_content: list[str] = field(default_factory=list)
# Current reasoning buffer (accumulates until we see a different event type)
_current_reasoning: str = ""
# Tool calls with their details
tool_calls: list[dict] = field(default_factory=list)
# Tool call ID to index mapping for updating results
_tool_call_id_map: dict[str, int] = field(default_factory=dict)
# Sequence of events in stream order
sequence: list[dict] = field(default_factory=list)
# Current position in answer text
_content_position: int = 0
# Track last event type to detect transitions
_last_event_type: str | None = None
def _flush_current_reasoning(self) -> None:
"""Flush accumulated reasoning to the list and add to sequence."""
if self._current_reasoning.strip():
self.reasoning_content.append(self._current_reasoning.strip())
self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1})
self._current_reasoning = ""
def record_text_chunk(self, text: str) -> None:
"""Record a text chunk event."""
if not text:
return
# Flush any pending reasoning first
if self._last_event_type == "thought":
self._flush_current_reasoning()
text_len = len(text)
start_pos = self._content_position
# If last event was also content, extend it; otherwise create new
if self.sequence and self.sequence[-1].get("type") == "content":
self.sequence[-1]["end"] = start_pos + text_len
else:
self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len})
self._content_position += text_len
self._last_event_type = "content"
def record_thought_chunk(self, text: str) -> None:
"""Record a thought/reasoning chunk event."""
if not text:
return
# Accumulate thought content
self._current_reasoning += text
self._last_event_type = "thought"
def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None:
"""Record a tool call event."""
if not tool_call_id:
return
# Flush any pending reasoning first
if self._last_event_type == "thought":
self._flush_current_reasoning()
# Check if this tool call already exists (we might get multiple chunks)
if tool_call_id in self._tool_call_id_map:
idx = self._tool_call_id_map[tool_call_id]
# Update arguments if provided
if tool_arguments:
self.tool_calls[idx]["arguments"] = tool_arguments
else:
# New tool call
tool_call = {
"id": tool_call_id or "",
"name": tool_name or "",
"arguments": tool_arguments or "",
"result": "",
"elapsed_time": None,
}
self.tool_calls.append(tool_call)
idx = len(self.tool_calls) - 1
self._tool_call_id_map[tool_call_id] = idx
self.sequence.append({"type": "tool_call", "index": idx})
self._last_event_type = "tool_call"
def record_tool_result(self, tool_call_id: str, result: str, tool_elapsed_time: float | None = None) -> None:
"""Record a tool result event (update existing tool call)."""
if not tool_call_id:
return
if tool_call_id in self._tool_call_id_map:
idx = self._tool_call_id_map[tool_call_id]
self.tool_calls[idx]["result"] = result
self.tool_calls[idx]["elapsed_time"] = tool_elapsed_time
def finalize(self) -> None:
"""Finalize the buffer, flushing any pending data."""
if self._last_event_type == "thought":
self._flush_current_reasoning()
def has_data(self) -> bool:
"""Check if there's any meaningful data recorded."""
return bool(self.reasoning_content or self.tool_calls or self.sequence)
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
@ -255,8 +144,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
# Stream event buffer for recording generation sequence
self._stream_buffer = StreamEventBuffer()
self._seed_graph_runtime_state_from_queue_manager()
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@ -496,7 +383,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events and record to stream buffer for sequence reconstruction."""
"""Handle text chunk events."""
delta_text = event.text
if delta_text is None:
return
@ -518,52 +405,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
tool_files = tool_result.files if tool_result else []
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
tool_icon = tool_payload.icon if tool_payload else None
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
# Record stream event based on chunk type
chunk_type = event.chunk_type or ChunkType.TEXT
match chunk_type:
case ChunkType.TEXT:
self._stream_buffer.record_text_chunk(delta_text)
self._task_state.answer += delta_text
case ChunkType.THOUGHT:
# Reasoning should not be part of final answer text
self._stream_buffer.record_thought_chunk(delta_text)
case ChunkType.TOOL_CALL:
self._stream_buffer.record_tool_call(
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
)
case ChunkType.TOOL_RESULT:
self._stream_buffer.record_tool_result(
tool_call_id=tool_call_id,
result=delta_text,
tool_elapsed_time=tool_elapsed_time,
)
self._task_state.answer += delta_text
case _:
pass
self._task_state.answer += delta_text
yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text,
message_id=self._message_id,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type.value if event.chunk_type else None,
tool_call_id=tool_call_id or None,
tool_name=tool_name or None,
tool_arguments=tool_arguments or None,
tool_files=tool_files,
tool_elapsed_time=tool_elapsed_time,
tool_icon=tool_icon,
tool_icon_dark=tool_icon_dark,
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
def _handle_iteration_start_event(
@ -931,7 +775,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer
answer_text = self._strip_think_blocks(answer_text)
if self._recorded_files:
# Remove markdown image links since we're storing files separately
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
@ -983,54 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
# Save generation detail (reasoning/tool calls/sequence) from stream buffer
self._save_generation_detail(session=session, message=message)
@staticmethod
def _strip_think_blocks(text: str) -> str:
"""Remove <think>...</think> blocks (including their content) from text."""
if not text or "<think" not in text.lower():
return text
clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL)
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
return clean_text
def _save_generation_detail(self, *, session: Session, message: Message) -> None:
"""
Save LLM generation detail for Chatflow using stream event buffer.
The buffer records the exact order of events as they streamed,
allowing accurate reconstruction of the generation sequence.
"""
# Finalize the stream buffer to flush any pending data
self._stream_buffer.finalize()
# Only save if there's meaningful data
if not self._stream_buffer.has_data():
return
reasoning_content = self._stream_buffer.reasoning_content
tool_calls = self._stream_buffer.tool_calls
sequence = self._stream_buffer.sequence
# Check if generation detail already exists for this message
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
if existing:
existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None
existing.tool_calls = json.dumps(tool_calls) if tool_calls else None
existing.sequence = json.dumps(sequence) if sequence else None
else:
generation_detail = LLMGenerationDetail(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
message_id=message.id,
reasoning_content=json.dumps(reasoning_content) if reasoning_content else None,
tool_calls=json.dumps(tool_calls) if tool_calls else None,
sequence=json.dumps(sequence) if sequence else None,
)
session.add(generation_detail)
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@ -3,8 +3,10 @@ from typing import cast
from sqlalchemy import select
from core.agent.agent_app_runner import AgentAppRunner
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
@ -12,7 +14,8 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from extensions.ext_database import db
@ -191,7 +194,22 @@ class AgentChatAppRunner(AppRunner):
raise ValueError("Message not found")
db.session.close()
runner = AgentAppRunner(
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
runner_cls = FunctionCallAgentRunner
else:
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
runner = runner_cls(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation_result,

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@ -75,24 +76,12 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
invalid_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, dict)
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
]
if invalid_dict_keys:
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
invalid_list_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, list)
and any(isinstance(item, dict) for item in v)
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
]
if invalid_list_dict_keys:
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
return user_inputs
@ -189,8 +178,12 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if value and not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
case _:
raise AssertionError("this statement should be unreachable.")

View File

@ -671,7 +671,7 @@ class WorkflowResponseConverter:
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
message_id=event.id,
id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,

View File

@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents: list[Document] = []
if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"):
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator):
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = args.get("original_document_id") or None
if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry:
if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
document_id = document_id or documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,

View File

@ -73,15 +73,9 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
invoke_from = self.application_generate_entity.invoke_from
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
user_id = None
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@ -123,7 +117,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=invoke_from.value,
invoke_from=self.application_generate_entity.invoke_from.value,
)
rag_pipeline_variables = []
@ -155,8 +149,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
user_from=user_from,
invoke_from=invoke_from,
)
# RUN WORKFLOW
@ -167,8 +159,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
variable_pool=variable_pool,
@ -214,12 +210,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
return workflow
def _init_rag_pipeline_graph(
self,
workflow: Workflow,
graph_runtime_state: GraphRuntimeState,
start_node_id: str | None = None,
user_from: UserFrom = UserFrom.ACCOUNT,
invoke_from: InvokeFrom = InvokeFrom.SERVICE_API,
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
) -> Graph:
"""
Init pipeline graph
@ -262,8 +253,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@ -20,6 +20,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -73,12 +74,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
invoke_from = self.application_generate_entity.invoke_from
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
@ -106,8 +102,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
root_node_id=self._root_node_id,
)
@ -126,8 +120,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,

View File

@ -13,7 +13,6 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
ChunkType,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
@ -484,33 +483,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if delta_text is None:
return
tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
tool_name = tool_payload.name if tool_payload and tool_payload.name else None
tool_arguments = tool_call.arguments if tool_call else None
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
tool_files = tool_result.files if tool_result else []
tool_icon = tool_payload.icon if tool_payload else None
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
# only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
yield self._text_chunk_to_stream_response(
text=delta_text,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_files=tool_files,
tool_elapsed_time=tool_elapsed_time,
tool_icon=tool_icon,
tool_icon_dark=tool_icon_dark,
)
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
@ -673,61 +650,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
session.add(workflow_app_log)
def _text_chunk_to_stream_response(
self,
text: str,
from_variable_selector: list[str] | None = None,
chunk_type: ChunkType | None = None,
tool_call_id: str | None = None,
tool_name: str | None = None,
tool_arguments: str | None = None,
tool_files: list[str] | None = None,
tool_error: str | None = None,
tool_elapsed_time: float | None = None,
tool_icon: str | dict | None = None,
tool_icon_dark: str | dict | None = None,
self, text: str, from_variable_selector: list[str] | None = None
) -> TextChunkStreamResponse:
"""
Handle completed event.
:param text: text
:return:
"""
from core.app.entities.task_entities import ChunkType as ResponseChunkType
response_chunk_type = ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT
data = TextChunkStreamResponse.Data(
text=text,
from_variable_selector=from_variable_selector,
chunk_type=response_chunk_type,
)
if response_chunk_type == ResponseChunkType.TOOL_CALL:
data = data.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
elif response_chunk_type == ResponseChunkType.TOOL_RESULT:
data = data.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_files": tool_files,
"tool_error": tool_error,
"tool_elapsed_time": tool_elapsed_time,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=data,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
)
return response

View File

@ -77,18 +77,10 @@ class WorkflowBasedAppRunner:
self._app_id = app_id
self._graph_engine_layers = graph_engine_layers
@staticmethod
def _resolve_user_from(invoke_from: InvokeFrom) -> UserFrom:
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
return UserFrom.ACCOUNT
return UserFrom.END_USER
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
user_from: UserFrom,
invoke_from: InvokeFrom,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
@ -113,8 +105,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@ -258,7 +250,7 @@ class WorkflowBasedAppRunner:
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@ -463,20 +455,12 @@ class WorkflowBasedAppRunner:
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
from core.app.entities.queue_entities import ChunkType as QueueChunkType
if event.is_final and not event.chunk:
return
self._publish_event(
QueueTextChunkEvent(
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
chunk_type=QueueChunkType(event.chunk_type.value),
tool_call=event.tool_call,
tool_result=event.tool_result,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):

View File

@ -42,8 +42,7 @@ class InvokeFrom(StrEnum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
# PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow.
PUBLISHED_PIPELINE = "published"
PUBLISHED = "published"
# VALIDATION indicates that this invocation is from validation.
VALIDATION = "validation"

View File

@ -1,70 +0,0 @@
"""
LLM Generation Detail entities.
Defines the structure for storing and transmitting LLM generation details
including reasoning content, tool calls, and their sequence.
"""
from typing import Literal
from pydantic import BaseModel, Field
class ContentSegment(BaseModel):
"""Represents a content segment in the generation sequence."""
type: Literal["content"] = "content"
start: int = Field(..., description="Start position in the text")
end: int = Field(..., description="End position in the text")
class ReasoningSegment(BaseModel):
"""Represents a reasoning segment in the generation sequence."""
type: Literal["reasoning"] = "reasoning"
index: int = Field(..., description="Index into reasoning_content array")
class ToolCallSegment(BaseModel):
"""Represents a tool call segment in the generation sequence."""
type: Literal["tool_call"] = "tool_call"
index: int = Field(..., description="Index into tool_calls array")
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
class ToolCallDetail(BaseModel):
"""Represents a tool call with its arguments and result."""
id: str = Field(default="", description="Unique identifier for the tool call")
name: str = Field(..., description="Name of the tool")
arguments: str = Field(default="", description="JSON string of tool arguments")
result: str = Field(default="", description="Result from the tool execution")
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
class LLMGenerationDetailData(BaseModel):
"""
Domain model for LLM generation detail.
Contains the structured data for reasoning content, tool calls,
and their display sequence.
"""
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
def is_empty(self) -> bool:
"""Check if there's any meaningful generation detail."""
return not self.reasoning_content and not self.tool_calls
def to_response_dict(self) -> dict:
"""Convert to dictionary for API response."""
return {
"reasoning_content": self.reasoning_content,
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
"sequence": [seg.model_dump() for seg in self.sequence],
}

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@ -177,17 +177,6 @@ class QueueLoopCompletedEvent(AppQueueEvent):
error: str | None = None
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class QueueTextChunkEvent(AppQueueEvent):
"""
QueueTextChunkEvent entity
@ -202,16 +191,6 @@ class QueueTextChunkEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
# Extended fields for Agent/Tool streaming
chunk_type: ChunkType = ChunkType.TEXT
"""type of the chunk"""
# Tool streaming payloads
tool_call: ToolCall | None = None
"""structured tool call info"""
tool_result: ToolResult | None = None
"""structured tool result info"""
class QueueAgentMessageEvent(AppQueueEvent):
"""

View File

@ -113,38 +113,6 @@ class MessageStreamResponse(StreamResponse):
answer: str
from_variable_selector: list[str] | None = None
# Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import)
chunk_type: str | None = None
"""type of the chunk: text, tool_call, tool_result, thought"""
# Tool call fields (when chunk_type == "tool_call")
tool_call_id: str | None = None
"""unique identifier for this tool call"""
tool_name: str | None = None
"""name of the tool being called"""
tool_arguments: str | None = None
"""accumulated tool arguments JSON"""
# Tool result fields (when chunk_type == "tool_result")
tool_files: list[str] | None = None
"""file IDs produced by tool"""
tool_error: str | None = None
"""error message if tool failed"""
tool_elapsed_time: float | None = None
"""elapsed time spent executing the tool"""
tool_icon: str | dict | None = None
"""icon of the tool"""
tool_icon_dark: str | dict | None = None
"""dark theme icon of the tool"""
def model_dump(self, *args, **kwargs) -> dict[str, object]:
kwargs.setdefault("exclude_none", True)
return super().model_dump(*args, **kwargs)
def model_dump_json(self, *args, **kwargs) -> str:
kwargs.setdefault("exclude_none", True)
return super().model_dump_json(*args, **kwargs)
class MessageAudioStreamResponse(StreamResponse):
"""
@ -614,17 +582,6 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
data: Data
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class TextChunkStreamResponse(StreamResponse):
"""
TextChunkStreamResponse entity
@ -638,36 +595,6 @@ class TextChunkStreamResponse(StreamResponse):
text: str
from_variable_selector: list[str] | None = None
# Extended fields for Agent/Tool streaming
chunk_type: ChunkType = ChunkType.TEXT
"""type of the chunk"""
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call_id: str | None = None
"""unique identifier for this tool call"""
tool_name: str | None = None
"""name of the tool being called"""
tool_arguments: str | None = None
"""accumulated tool arguments JSON"""
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_files: list[str] | None = None
"""file IDs produced by tool"""
tool_error: str | None = None
"""error message if tool failed"""
# Tool elapsed time fields (when chunk_type == TOOL_RESULT)
tool_elapsed_time: float | None = None
"""elapsed time spent executing the tool"""
def model_dump(self, *args, **kwargs) -> dict[str, object]:
kwargs.setdefault("exclude_none", True)
return super().model_dump(*args, **kwargs)
def model_dump_json(self, *args, **kwargs) -> str:
kwargs.setdefault("exclude_none", True)
return super().model_dump_json(*args, **kwargs)
event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data
@ -816,7 +743,7 @@ class AgentLogStreamResponse(StreamResponse):
"""
node_execution_id: str
message_id: str
id: str
label: str
parent_id: str | None = None
error: str | None = None

View File

@ -75,7 +75,7 @@ class AnnotationReplyFeature:
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
annotation.question_text,
annotation.question,
annotation.content,
query,
user_id,

View File

@ -1,60 +0,0 @@
import logging
from core.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
logger = logging.getLogger(__name__)
class ConversationVariablePersistenceLayer(GraphEngineLayer):
def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
super().__init__()
self._conversation_variable_updater = conversation_variable_updater
def on_graph_start(self) -> None:
pass
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunSucceededEvent):
return
if event.node_type != NodeType.VARIABLE_ASSIGNER:
return
if self.graph_runtime_state is None:
return
updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
if not updated_variables:
return
conversation_id = self.graph_runtime_state.system_variable.conversation_id
if conversation_id is None:
return
updated_any = False
for item in updated_variables:
selector = item.selector
if len(selector) < 2:
logger.warning("Conversation variable selector invalid. selector=%s", selector)
continue
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,
)
continue
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
updated_any = True
if updated_any:
self._conversation_variable_updater.flush()
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -66,7 +66,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
super().__init__()
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
@ -99,6 +98,8 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
if not isinstance(event, GraphRunPausedEvent):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)

View File

@ -33,7 +33,6 @@ class TriggerPostLayer(GraphEngineLayer):
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
@ -58,6 +57,10 @@ class TriggerPostLayer(GraphEngineLayer):
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
# Extract relevant data from result
if not self.graph_runtime_state:
logger.exception("Graph runtime state is not set")
return
outputs = self.graph_runtime_state.outputs
# BASICLY, workflow_execution_id is the same as workflow_run_id

View File

@ -1,5 +1,4 @@
import logging
import re
import time
from collections.abc import Generator
from threading import Thread
@ -59,7 +58,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought
from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
@ -69,8 +68,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
@ -412,136 +409,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
)
# Save LLM generation detail if there's reasoning_content
self._save_generation_detail(session=session, message=message, llm_result=llm_result)
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None:
"""
Save LLM generation detail for Completion/Chat/Agent-Chat applications.
For Agent-Chat, also merges MessageAgentThought records.
"""
import json
reasoning_list: list[str] = []
tool_calls_list: list[dict] = []
sequence: list[dict] = []
answer = message.answer or ""
# Check if this is Agent-Chat mode by looking for agent thoughts
agent_thoughts = (
session.query(MessageAgentThought)
.filter_by(message_id=message.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
if agent_thoughts:
# Agent-Chat mode: merge MessageAgentThought records
content_pos = 0
cleaned_answer_parts: list[str] = []
for thought in agent_thoughts:
# Add thought/reasoning
if thought.thought:
reasoning_text = thought.thought
if "<think" in reasoning_text.lower():
clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
if extracted_reasoning:
reasoning_text = extracted_reasoning
thought.thought = clean_text or extracted_reasoning
reasoning_list.append(reasoning_text)
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
# Add tool calls
if thought.tool:
tool_calls_list.append(
{
"name": thought.tool,
"arguments": thought.tool_input or "",
"result": thought.observation or "",
}
)
sequence.append({"type": "tool_call", "index": len(tool_calls_list) - 1})
# Add answer content if present
if thought.answer:
content_text = thought.answer
if "<think" in content_text.lower():
clean_answer, extracted_reasoning = self._split_reasoning_from_answer(content_text)
if extracted_reasoning:
reasoning_list.append(extracted_reasoning)
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
content_text = clean_answer
thought.answer = clean_answer or content_text
if content_text:
start = content_pos
end = content_pos + len(content_text)
sequence.append({"type": "content", "start": start, "end": end})
content_pos = end
cleaned_answer_parts.append(content_text)
if cleaned_answer_parts:
merged_answer = "".join(cleaned_answer_parts)
message.answer = merged_answer
llm_result.message.content = merged_answer
else:
# Completion/Chat mode: use reasoning_content from llm_result
reasoning_content = llm_result.reasoning_content
if not reasoning_content and answer:
# Extract reasoning from <think> blocks and clean the final answer
clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
if reasoning_content:
answer = clean_answer
llm_result.message.content = clean_answer
llm_result.reasoning_content = reasoning_content
message.answer = clean_answer
if reasoning_content:
reasoning_list = [reasoning_content]
# Content comes first, then reasoning
if answer:
sequence.append({"type": "content", "start": 0, "end": len(answer)})
sequence.append({"type": "reasoning", "index": 0})
# Only save if there's meaningful generation detail
if not reasoning_list and not tool_calls_list:
return
# Check if generation detail already exists
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
if existing:
existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
existing.sequence = json.dumps(sequence) if sequence else None
else:
generation_detail = LLMGenerationDetail(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
message_id=message.id,
reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
sequence=json.dumps(sequence) if sequence else None,
)
session.add(generation_detail)
@classmethod
def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
"""
Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
"""
matches = cls._THINK_PATTERN.findall(text)
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
clean_text = cls._THINK_PATTERN.sub("", text)
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
return clean_text, reasoning_content or ""
def _handle_stop(self, event: QueueStopEvent):
"""
Handle stop.

View File

@ -232,31 +232,15 @@ class MessageCycleManager:
answer: str,
message_id: str,
from_variable_selector: list[str] | None = None,
chunk_type: str | None = None,
tool_call_id: str | None = None,
tool_name: str | None = None,
tool_arguments: str | None = None,
tool_files: list[str] | None = None,
tool_error: str | None = None,
tool_elapsed_time: float | None = None,
tool_icon: str | dict | None = None,
tool_icon_dark: str | dict | None = None,
event_type: StreamEvent | None = None,
) -> MessageStreamResponse:
"""
Message to stream response.
:param answer: answer
:param message_id: message id
:param from_variable_selector: from variable selector
:param chunk_type: type of the chunk (text, function_call, tool_result, thought)
:param tool_call_id: unique identifier for this tool call
:param tool_name: name of the tool being called
:param tool_arguments: accumulated tool arguments JSON
:param tool_files: file IDs produced by tool
:param tool_error: error message if tool failed
:return:
"""
response = MessageStreamResponse(
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
@ -264,35 +248,6 @@ class MessageCycleManager:
event=event_type or StreamEvent.MESSAGE,
)
if chunk_type:
response = response.model_copy(update={"chunk_type": chunk_type})
if chunk_type == "tool_call":
response = response.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
elif chunk_type == "tool_result":
response = response.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_files": tool_files,
"tool_error": tool_error,
"tool_elapsed_time": tool_elapsed_time,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
return response
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.

Some files were not shown because too many files have changed in this diff Show More