Compare commits

..

18 Commits

Author SHA1 Message Date
c0c64c75d6 chore: upgrade fickling to 0.1.6 (#30495) 2026-01-16 02:30:04 -08:00
9b0984eab7 chore: Harden API image Node.js runtime install (#30497) 2026-01-16 02:29:36 -08:00
83a943d8c4 build: require node 24.13.0 (#30945) (#31027)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-01-15 11:47:27 +08:00
c21f7c48fb fix: update permission in member list caused page crash (#30164) 2026-01-12 18:45:28 -08:00
8371a26cdf fix: reinstall packages, rebuild pnpm-lock.yaml file 2026-01-12 15:42:39 +08:00
a6e43f0fa0 build: limit esbuild, glob, docker base version to avoid cve (#30848) 2026-01-12 15:35:46 +08:00
30f9199fba Merge remote-tracking branch 'origin/hotfix/1.11.2-fix.3' into release/e-1.11.2 2026-01-12 13:05:04 +08:00
b47afdd314 Revert "feat: implement workspace permission checks for member invitations and owner transfer"
This reverts commit 248871fca1.
2026-01-11 20:17:55 -08:00
fc81605ae8 feat: add queue credential sync when tenant created
- Add queue credential sync functionality when tenant is created
- Replace FeatureService with dify_config for enterprise feature check
- Improve logging format in WorkspaceSyncService
- Update timestamp creation to use UTC
- Simplify tenant creation event emission by removing unnecessary source parameter
2026-01-11 18:43:34 -08:00
248871fca1 feat: implement workspace permission checks for member invitations and owner transfer 2026-01-11 18:42:46 -08:00
4e0d3c224f fix: web app login code encrypt (#30705) 2026-01-08 15:39:37 +08:00
c9858f851f feat: add decryption decorators for password and code fields in login API (#30680) 2026-01-07 23:24:10 -08:00
c17052b8b4 fix: create from template permission set error 2026-01-07 15:57:23 +08:00
70571b53ad fix: use query param for delete method (#30206) 2025-12-29 21:48:54 -08:00
44ef3cc27d fix multimodal embedding retrival test 2025-12-26 17:30:51 +08:00
676063890c fix multimodal embedding retrival test 2025-12-26 17:05:37 +08:00
901cc64ac9 fix multimodal embedding retrival test 2025-12-26 17:04:46 +08:00
894a3c03a2 fix: load i18n on server (#30171) 2025-12-26 10:30:27 +08:00
175 changed files with 3958 additions and 11499 deletions

View File

@ -1,8 +0,0 @@
{
"enabledPlugins": {
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true
}
}

View File

@ -0,0 +1,19 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@ -1,483 +0,0 @@
---
name: component-refactoring
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
---
# Dify Component Refactoring Skill
Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.
> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
## Quick Reference
### Commands (run from `web/`)
Use paths relative to `web/` (e.g., `app/components/...`).
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.
```bash
cd web
# Generate refactoring prompt
pnpm refactor-component <path>
# Output refactoring analysis as JSON
pnpm refactor-component <path> --json
# Generate testing prompt (after refactoring)
pnpm analyze-component <path>
# Output testing analysis as JSON
pnpm analyze-component <path> --json
```
### Complexity Analysis
```bash
# Analyze component complexity
pnpm analyze-component <path> --json
# Key metrics to check:
# - complexity: normalized score 0-100 (target < 50)
# - maxComplexity: highest single function complexity
# - lineCount: total lines (target < 300)
```
### Complexity Score Interpretation
| Score | Level | Action |
|-------|-------|--------|
| 0-25 | 🟢 Simple | Ready for testing |
| 26-50 | 🟡 Medium | Consider minor refactoring |
| 51-75 | 🟠 Complex | **Refactor before testing** |
| 76-100 | 🔴 Very Complex | **Must refactor** |
## Core Refactoring Patterns
### Pattern 1: Extract Custom Hooks
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
```typescript
// ❌ Before: Complex state logic in component
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// 50+ lines of state management logic...
return <div>...</div>
}
// ✅ After: Extract to custom hook
// hooks/use-model-config.ts
export const useModelConfig = (appId: string) => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// Related state management logic here
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
}
// Component becomes cleaner
const Configuration: FC = () => {
const { modelConfig, setModelConfig } = useModelConfig(appId)
return <div>...</div>
}
```
**Dify Examples**:
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
- `web/app/components/app/configuration/debug/hooks.tsx`
- `web/app/components/workflow/hooks/use-workflow.ts`
### Pattern 2: Extract Sub-Components
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
```typescript
// ❌ Before: Monolithic JSX with multiple sections
const AppInfo = () => {
return (
<div>
{/* 100 lines of header UI */}
{/* 100 lines of operations UI */}
{/* 100 lines of modals */}
</div>
)
}
// ✅ After: Split into focused components
// app-info/
// ├── index.tsx (orchestration only)
// ├── app-header.tsx (header UI)
// ├── app-operations.tsx (operations UI)
// └── app-modals.tsx (modal management)
const AppInfo = () => {
const { showModal, setShowModal } = useAppInfoModals()
return (
<div>
<AppHeader appDetail={appDetail} />
<AppOperations onAction={handleAction} />
<AppModals show={showModal} onClose={() => setShowModal(null)} />
</div>
)
}
```
**Dify Examples**:
- `web/app/components/app/configuration/` directory structure
- `web/app/components/workflow/nodes/` per-node organization
### Pattern 3: Simplify Conditional Logic
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
```typescript
// ❌ Before: Deeply nested conditionals
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh />
case LanguagesSupported[7]:
return <TemplateChatJa />
default:
return <TemplateChatEn />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
// Another 15 lines...
}
// More conditions...
}, [appDetail, locale])
// ✅ After: Use lookup tables + early returns
const TEMPLATE_MAP = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
// ...
},
}
const Template = useMemo(() => {
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
if (!modeTemplates) return null
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
### Pattern 4: Extract API/Data Logic
**When**: Component directly handles API calls, data transformation, or complex async operations.
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
```typescript
// ❌ Before: API logic in component
const MCPServiceCard = () => {
const [basicAppConfig, setBasicAppConfig] = useState({})
useEffect(() => {
if (isBasicApp && appId) {
(async () => {
const res = await fetchAppDetail({ url: '/apps', id: appId })
setBasicAppConfig(res?.model_config || {})
})()
}
}, [appId, isBasicApp])
// More API-related logic...
}
// ✅ After: Extract to data hook using React Query
// use-app-config.ts
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'
const NAME_SPACE = 'appConfig'
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
return useQuery({
enabled: isBasicApp && !!appId,
queryKey: [NAME_SPACE, 'detail', appId],
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || {},
})
}
// Component becomes cleaner
const MCPServiceCard = () => {
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
// UI only
}
```
**React Query Best Practices in Dify**:
- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx`
**Dify Examples**:
- `web/service/use-workflow.ts`
- `web/service/use-common.ts`
- `web/service/knowledge/use-dataset.ts`
- `web/service/knowledge/use-document.ts`
### Pattern 5: Extract Modal/Dialog Management
**When**: Component manages multiple modals with complex open/close states.
**Dify Convention**: Modals should be extracted with their state management.
```typescript
// ❌ Before: Multiple modal states in component
const AppInfo = () => {
const [showEditModal, setShowEditModal] = useState(false)
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
const [showSwitchModal, setShowSwitchModal] = useState(false)
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
// 5+ more modal states...
}
// ✅ After: Extract to modal management hook
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
const useAppInfoModals = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
const closeModal = useCallback(() => setActiveModal(null), [])
return {
activeModal,
openModal,
closeModal,
isOpen: (type: ModalType) => activeModal === type,
}
}
```
### Pattern 6: Extract Form Logic
**When**: Complex form validation, submission handling, or field transformation.
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
```typescript
// ✅ Use existing form infrastructure
import { useAppForm } from '@/app/components/base/form'
const ConfigForm = () => {
const form = useAppForm({
defaultValues: { name: '', description: '' },
onSubmit: handleSubmit,
})
return <form.Provider>...</form.Provider>
}
```
## Dify-Specific Refactoring Guidelines
### 1. Context Provider Extraction
**When**: Component provides complex context values with multiple states.
```typescript
// ❌ Before: Large context value object
const value = {
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
// 50+ more properties...
}
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
// ✅ After: Split into domain-specific contexts
<ModelConfigProvider value={modelConfigValue}>
<DatasetConfigProvider value={datasetConfigValue}>
<UIConfigProvider value={uiConfigValue}>
{children}
</UIConfigProvider>
</DatasetConfigProvider>
</ModelConfigProvider>
```
**Dify Reference**: `web/context/` directory structure
### 2. Workflow Node Components
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
**Conventions**:
- Keep node logic in `use-interactions.ts`
- Extract panel UI to separate files
- Use `_base` components for common patterns
```
nodes/<node-type>/
├── index.tsx # Node registration
├── node.tsx # Node visual component
├── panel.tsx # Configuration panel
├── use-interactions.ts # Node-specific hooks
└── types.ts # Type definitions
```
### 3. Configuration Components
**When**: Refactoring app configuration components.
**Conventions**:
- Separate config sections into subdirectories
- Use existing patterns from `web/app/components/app/configuration/`
- Keep feature toggles in dedicated components
### 4. Tool/Plugin Components
**When**: Refactoring tool-related components (`web/app/components/tools/`).
**Conventions**:
- Follow existing modal patterns
- Use service hooks from `web/service/use-tools.ts`
- Keep provider-specific logic isolated
## Refactoring Workflow
### Step 1: Generate Refactoring Prompt
```bash
pnpm refactor-component <path>
```
This command will:
- Analyze component complexity and features
- Identify specific refactoring actions needed
- Generate a prompt for AI assistant (auto-copied to clipboard on macOS)
- Provide detailed requirements based on detected patterns
### Step 2: Analyze Details
```bash
pnpm analyze-component <path> --json
```
Identify:
- Total complexity score
- Max function complexity
- Line count
- Features detected (state, effects, API, etc.)
### Step 3: Plan
Create a refactoring plan based on detected features:
| Detected Feature | Refactoring Action |
|------------------|-------------------|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
| `hasAPI: true` | Extract data/service hook |
| `hasEvents: true` (many) | Extract event handlers |
| `lineCount > 300` | Split into sub-components |
| `maxComplexity > 50` | Simplify conditional logic |
### Step 4: Execute Incrementally
1. **Extract one piece at a time**
2. **Run lint, type-check, and tests after each extraction**
3. **Verify functionality before next step**
```
For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code │
│ 2. Run: pnpm lint:fix │
│ 3. Run: pnpm type-check:tsgo │
│ 4. Run: pnpm test │
│ 5. Test functionality manually │
│ 6. PASS? → Next extraction │
│ FAIL? → Fix before continuing │
└────────────────────────────────────────┘
```
### Step 5: Verify
After refactoring:
```bash
# Re-run refactor command to verify improvements
pnpm refactor-component <path>
# If complexity < 25 and lines < 200, you'll see:
# ✅ COMPONENT IS WELL-STRUCTURED
# For detailed metrics:
pnpm analyze-component <path> --json
# Target metrics:
# - complexity < 50
# - lineCount < 300
# - maxComplexity < 30
```
## Common Mistakes to Avoid
### ❌ Over-Engineering
```typescript
// ❌ Too many tiny hooks
const useButtonText = () => useState('Click')
const useButtonDisabled = () => useState(false)
const useButtonLoading = () => useState(false)
// ✅ Cohesive hook with related state
const useButtonState = () => {
const [text, setText] = useState('Click')
const [disabled, setDisabled] = useState(false)
const [loading, setLoading] = useState(false)
return { text, setText, disabled, setDisabled, loading, setLoading }
}
```
### ❌ Breaking Existing Patterns
- Follow existing directory structures
- Maintain naming conventions
- Preserve export patterns for compatibility
### ❌ Premature Abstraction
- Only extract when there's clear complexity benefit
- Don't create abstractions for single-use code
- Keep refactored code in the same domain area
## References
### Dify Codebase Examples
- **Hook extraction**: `web/app/components/app/configuration/hooks/`
- **Component splitting**: `web/app/components/app/configuration/`
- **Service hooks**: `web/service/use-*.ts`
- **Workflow patterns**: `web/app/components/workflow/hooks/`
- **Form patterns**: `web/app/components/base/form/`
### Related Skills
- `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification

View File

@ -1,493 +0,0 @@
# Complexity Reduction Patterns
This document provides patterns for reducing cognitive complexity in Dify React components.
## Understanding Complexity
### SonarJS Cognitive Complexity
The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:
- **Total Complexity**: Sum of all functions' complexity in the file
- **Max Complexity**: Highest single function complexity
### What Increases Complexity
| Pattern | Complexity Impact |
|---------|-------------------|
| `if/else` | +1 per branch |
| Nested conditions | +1 per nesting level |
| `switch/case` | +1 per case |
| `for/while/do` | +1 per loop |
| `&&`/`||` chains | +1 per operator |
| Nested callbacks | +1 per nesting level |
| `try/catch` | +1 per catch |
| Ternary expressions | +1 per nesting |
## Pattern 1: Replace Conditionals with Lookup Tables
**Before** (complexity: ~15):
```typescript
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateChatJa appDetail={appDetail} />
default:
return <TemplateChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateAdvancedChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateAdvancedChatJa appDetail={appDetail} />
default:
return <TemplateAdvancedChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
// Similar pattern...
}
return null
}, [appDetail, locale])
```
**After** (complexity: ~3):
```typescript
// Define lookup table outside component
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
default: TemplateAdvancedChatEn,
},
[AppModeEnum.WORKFLOW]: {
[LanguagesSupported[1]]: TemplateWorkflowZh,
[LanguagesSupported[7]]: TemplateWorkflowJa,
default: TemplateWorkflowEn,
},
// ...
}
// Clean component logic
const Template = useMemo(() => {
if (!appDetail?.mode) return null
const templates = TEMPLATE_MAP[appDetail.mode]
if (!templates) return null
const TemplateComponent = templates[locale] ?? templates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
## Pattern 2: Use Early Returns
**Before** (complexity: ~10):
```typescript
const handleSubmit = () => {
if (isValid) {
if (hasChanges) {
if (isConnected) {
submitData()
} else {
showConnectionError()
}
} else {
showNoChangesMessage()
}
} else {
showValidationError()
}
}
```
**After** (complexity: ~4):
```typescript
const handleSubmit = () => {
if (!isValid) {
showValidationError()
return
}
if (!hasChanges) {
showNoChangesMessage()
return
}
if (!isConnected) {
showConnectionError()
return
}
submitData()
}
```
## Pattern 3: Extract Complex Conditions
**Before** (complexity: high):
```typescript
const canPublish = (() => {
if (mode !== AppModeEnum.COMPLETION) {
if (!isAdvancedMode)
return true
if (modelModeType === ModelModeType.completion) {
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
return false
return true
}
return true
}
return !promptEmpty
})()
```
**After** (complexity: lower):
```typescript
// Extract to named functions
const canPublishInCompletionMode = () => !promptEmpty
const canPublishInChatMode = () => {
if (!isAdvancedMode) return true
if (modelModeType !== ModelModeType.completion) return true
return hasSetBlockStatus.history && hasSetBlockStatus.query
}
// Clean main logic
const canPublish = mode === AppModeEnum.COMPLETION
? canPublishInCompletionMode()
: canPublishInChatMode()
```
## Pattern 4: Replace Chained Ternaries
**Before** (complexity: ~5):
```typescript
const statusText = serverActivated
? t('status.running')
: serverPublished
? t('status.inactive')
: appUnpublished
? t('status.unpublished')
: t('status.notConfigured')
```
**After** (complexity: ~2):
```typescript
const getStatusText = () => {
if (serverActivated) return t('status.running')
if (serverPublished) return t('status.inactive')
if (appUnpublished) return t('status.unpublished')
return t('status.notConfigured')
}
const statusText = getStatusText()
```
Or use lookup:
```typescript
const STATUS_TEXT_MAP = {
running: 'status.running',
inactive: 'status.inactive',
unpublished: 'status.unpublished',
notConfigured: 'status.notConfigured',
} as const
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
if (serverActivated) return 'running'
if (serverPublished) return 'inactive'
if (appUnpublished) return 'unpublished'
return 'notConfigured'
}
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
```
## Pattern 5: Flatten Nested Loops
**Before** (complexity: high):
```typescript
const processData = (items: Item[]) => {
const results: ProcessedItem[] = []
for (const item of items) {
if (item.isValid) {
for (const child of item.children) {
if (child.isActive) {
for (const prop of child.properties) {
if (prop.value !== null) {
results.push({
itemId: item.id,
childId: child.id,
propValue: prop.value,
})
}
}
}
}
}
}
return results
}
```
**After** (complexity: lower):
```typescript
// Use functional approach
const processData = (items: Item[]) => {
return items
.filter(item => item.isValid)
.flatMap(item =>
item.children
.filter(child => child.isActive)
.flatMap(child =>
child.properties
.filter(prop => prop.value !== null)
.map(prop => ({
itemId: item.id,
childId: child.id,
propValue: prop.value,
}))
)
)
}
```
## Pattern 6: Extract Event Handler Logic
**Before** (complexity: high in component):
```typescript
const Component = () => {
const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
let newDatasets = data
if (data.find(item => !item.name)) {
const newSelected = produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const newItem = dataSets.find(i => i.id === item.id)
if (newItem)
draft[index] = newItem
}
})
})
setDataSets(newSelected)
newDatasets = newSelected
}
else {
setDataSets(data)
}
hideSelectDataSet()
// 40 more lines of logic...
}
return <div>...</div>
}
```
**After** (complexity: lower):
```typescript
// Extract to hook or utility
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
const normalizeSelection = (data: DataSet[]) => {
const hasUnloadedItem = data.some(item => !item.name)
if (!hasUnloadedItem) return data
return produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const existing = dataSets.find(i => i.id === item.id)
if (existing) draft[index] = existing
}
})
})
}
const hasSelectionChanged = (newData: DataSet[]) => {
return !isEqual(
newData.map(item => item.id),
dataSets.map(item => item.id)
)
}
return { normalizeSelection, hasSelectionChanged }
}
// Component becomes cleaner
const Component = () => {
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
const handleSelect = (data: DataSet[]) => {
if (!hasSelectionChanged(data)) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
const normalized = normalizeSelection(data)
setDataSets(normalized)
hideSelectDataSet()
}
return <div>...</div>
}
```
## Pattern 7: Reduce Boolean Logic Complexity
**Before** (complexity: ~8):
```typescript
const toggleDisabled = hasInsufficientPermissions
|| appUnpublished
|| missingStartNode
|| triggerModeDisabled
|| (isAdvancedApp && !currentWorkflow?.graph)
|| (isBasicApp && !basicAppConfig.updated_at)
```
**After** (complexity: ~3):
```typescript
// Extract meaningful boolean functions
const isAppReady = () => {
if (isAdvancedApp) return !!currentWorkflow?.graph
return !!basicAppConfig.updated_at
}
const hasRequiredPermissions = () => {
return isCurrentWorkspaceEditor && !hasInsufficientPermissions
}
const canToggle = () => {
if (!hasRequiredPermissions()) return false
if (!isAppReady()) return false
if (missingStartNode) return false
if (triggerModeDisabled) return false
return true
}
const toggleDisabled = !canToggle()
```
## Pattern 8: Simplify useMemo/useCallback Dependencies
**Before** (complexity: multiple recalculations):
```typescript
const payload = useMemo(() => {
let parameters: Parameter[] = []
let outputParameters: OutputParameter[] = []
if (!published) {
parameters = (inputs || []).map((item) => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
outputParameters = (outputs || []).map((item) => ({
name: item.variable,
description: '',
type: item.value_type,
}))
}
else if (detail && detail.tool) {
parameters = (inputs || []).map((item) => ({
// Complex transformation...
}))
outputParameters = (outputs || []).map((item) => ({
// Complex transformation...
}))
}
return {
icon: detail?.icon || icon,
label: detail?.label || name,
// ...more fields
}
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
```
**After** (complexity: separated concerns):
```typescript
// Separate transformations
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
return useMemo(() => {
if (!published) {
return inputs.map(item => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
}
if (!detail?.tool) return []
return inputs.map(item => ({
name: item.variable,
required: item.required,
type: item.type === 'paragraph' ? 'string' : item.type,
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
}))
}, [inputs, detail, published])
}
// Component uses hook
const parameters = useParameterTransform(inputs, detail, published)
const outputParameters = useOutputTransform(outputs, detail, published)
const payload = useMemo(() => ({
icon: detail?.icon || icon,
label: detail?.label || name,
parameters,
outputParameters,
// ...
}), [detail, icon, name, parameters, outputParameters])
```
## Target Metrics After Refactoring
| Metric | Target |
|--------|--------|
| Total Complexity | < 50 |
| Max Function Complexity | < 30 |
| Function Length | < 30 lines |
| Nesting Depth | 3 levels |
| Conditional Chains | 3 conditions |

View File

@ -1,477 +0,0 @@
# Component Splitting Patterns
This document provides detailed guidance on splitting large components into smaller, focused components in Dify.
## When to Split Components
Split a component when you identify:
1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
1. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
1. **Repeated patterns** - Similar UI structures used multiple times
1. **300+ lines** - Component exceeds manageable size
1. **Modal clusters** - Multiple modals rendered in one component
## Splitting Strategies
### Strategy 1: Section-Based Splitting
Identify visual sections and extract each as a component.
```typescript
// ❌ Before: Monolithic component (500+ lines)
const ConfigurationPage = () => {
return (
<div>
{/* Header Section - 50 lines */}
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher ... />
</div>
</div>
{/* Config Section - 200 lines */}
<div className="config">
<Config />
</div>
{/* Debug Section - 150 lines */}
<div className="debug">
<Debug ... />
</div>
{/* Modals Section - 100 lines */}
{showSelectDataSet && <SelectDataSet ... />}
{showHistoryModal && <EditHistoryModal ... />}
{showUseGPT4Confirm && <Confirm ... />}
</div>
)
}
// ✅ After: Split into focused components
// configuration/
// ├── index.tsx (orchestration)
// ├── configuration-header.tsx
// ├── configuration-content.tsx
// ├── configuration-debug.tsx
// └── configuration-modals.tsx
// configuration-header.tsx
interface ConfigurationHeaderProps {
isAdvancedMode: boolean
onPublish: () => void
}
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
isAdvancedMode,
onPublish,
}) => {
const { t } = useTranslation()
return (
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher onPublish={onPublish} />
</div>
</div>
)
}
// index.tsx (orchestration only)
const ConfigurationPage = () => {
const { modelConfig, setModelConfig } = useModelConfig()
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
<ConfigurationHeader
isAdvancedMode={isAdvancedMode}
onPublish={handlePublish}
/>
<ConfigurationContent
modelConfig={modelConfig}
onConfigChange={setModelConfig}
/>
{!isMobile && (
<ConfigurationDebug
inputs={inputs}
onSetting={handleSetting}
/>
)}
<ConfigurationModals
activeModal={activeModal}
onClose={closeModal}
/>
</div>
)
}
```
### Strategy 2: Conditional Block Extraction
Extract large conditional rendering blocks.
```typescript
// ❌ Before: Large conditional blocks
const AppInfo = () => {
return (
<div>
{expand ? (
<div className="expanded">
{/* 100 lines of expanded view */}
</div>
) : (
<div className="collapsed">
{/* 50 lines of collapsed view */}
</div>
)}
</div>
)
}
// ✅ After: Separate view components
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="expanded">
{/* Clean, focused expanded view */}
</div>
)
}
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="collapsed">
{/* Clean, focused collapsed view */}
</div>
)
}
const AppInfo = () => {
return (
<div>
{expand
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
}
</div>
)
}
```
### Strategy 3: Modal Extraction
Extract modals with their trigger logic.
```typescript
// ❌ Before: Multiple modals in one component
const AppInfo = () => {
const [showEdit, setShowEdit] = useState(false)
const [showDuplicate, setShowDuplicate] = useState(false)
const [showDelete, setShowDelete] = useState(false)
const [showSwitch, setShowSwitch] = useState(false)
const onEdit = async (data) => { /* 20 lines */ }
const onDuplicate = async (data) => { /* 20 lines */ }
const onDelete = async () => { /* 15 lines */ }
return (
<div>
{/* Main content */}
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
{showSwitch && <SwitchModal ... />}
</div>
)
}
// ✅ After: Modal manager component
// app-info-modals.tsx
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
interface AppInfoModalsProps {
appDetail: AppDetail
activeModal: ModalType
onClose: () => void
onSuccess: () => void
}
const AppInfoModals: FC<AppInfoModalsProps> = ({
appDetail,
activeModal,
onClose,
onSuccess,
}) => {
const handleEdit = async (data) => { /* logic */ }
const handleDuplicate = async (data) => { /* logic */ }
const handleDelete = async () => { /* logic */ }
return (
<>
{activeModal === 'edit' && (
<EditModal
appDetail={appDetail}
onConfirm={handleEdit}
onClose={onClose}
/>
)}
{activeModal === 'duplicate' && (
<DuplicateModal
appDetail={appDetail}
onConfirm={handleDuplicate}
onClose={onClose}
/>
)}
{activeModal === 'delete' && (
<DeleteConfirm
onConfirm={handleDelete}
onClose={onClose}
/>
)}
{activeModal === 'switch' && (
<SwitchModal
appDetail={appDetail}
onClose={onClose}
/>
)}
</>
)
}
// Parent component
const AppInfo = () => {
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
{/* Main content with openModal triggers */}
<Button onClick={() => openModal('edit')}>Edit</Button>
<AppInfoModals
appDetail={appDetail}
activeModal={activeModal}
onClose={closeModal}
onSuccess={handleSuccess}
/>
</div>
)
}
```
### Strategy 4: List Item Extraction
Extract repeated item rendering.
```typescript
// ❌ Before: Inline item rendering
const OperationsList = () => {
return (
<div>
{operations.map(op => (
<div key={op.id} className="operation-item">
<span className="icon">{op.icon}</span>
<span className="title">{op.title}</span>
<span className="description">{op.description}</span>
<button onClick={() => op.onClick()}>
{op.actionLabel}
</button>
{op.badge && <Badge>{op.badge}</Badge>}
{/* More complex rendering... */}
</div>
))}
</div>
)
}
// ✅ After: Extracted item component
interface OperationItemProps {
operation: Operation
onAction: (id: string) => void
}
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
return (
<div className="operation-item">
<span className="icon">{operation.icon}</span>
<span className="title">{operation.title}</span>
<span className="description">{operation.description}</span>
<button onClick={() => onAction(operation.id)}>
{operation.actionLabel}
</button>
{operation.badge && <Badge>{operation.badge}</Badge>}
</div>
)
}
const OperationsList = () => {
const handleAction = useCallback((id: string) => {
const op = operations.find(o => o.id === id)
op?.onClick()
}, [operations])
return (
<div>
{operations.map(op => (
<OperationItem
key={op.id}
operation={op}
onAction={handleAction}
/>
))}
</div>
)
}
```
## Directory Structure Patterns
### Pattern A: Flat Structure (Simple Components)
For components with 2-3 sub-components:
```
component-name/
├── index.tsx # Main component
├── sub-component-a.tsx
├── sub-component-b.tsx
└── types.ts # Shared types
```
### Pattern B: Nested Structure (Complex Components)
For components with many sub-components:
```
component-name/
├── index.tsx # Main orchestration
├── types.ts # Shared types
├── hooks/
│ ├── use-feature-a.ts
│ └── use-feature-b.ts
├── components/
│ ├── header/
│ │ └── index.tsx
│ ├── content/
│ │ └── index.tsx
│ └── modals/
│ └── index.tsx
└── utils/
└── helpers.ts
```
### Pattern C: Feature-Based Structure (Dify Standard)
Following Dify's existing patterns:
```
configuration/
├── index.tsx # Main page component
├── base/ # Base/shared components
│ ├── feature-panel/
│ ├── group-name/
│ └── operation-btn/
├── config/ # Config section
│ ├── index.tsx
│ ├── agent/
│ └── automatic/
├── dataset-config/ # Dataset section
│ ├── index.tsx
│ ├── card-item/
│ └── params-config/
├── debug/ # Debug section
│ ├── index.tsx
│ └── hooks.tsx
└── hooks/ # Shared hooks
└── use-advanced-prompt-config.ts
```
## Props Design
### Minimal Props Principle
Pass only what's needed:
```typescript
// ❌ Bad: Passing entire objects when only some fields needed
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
// ✅ Good: Destructure to minimum required
<ConfigHeader
appName={appDetail.name}
isAdvancedMode={modelConfig.isAdvanced}
onPublish={handlePublish}
/>
```
### Callback Props Pattern
Use callbacks for child-to-parent communication:
```typescript
// Parent
const Parent = () => {
const [value, setValue] = useState('')
return (
<Child
value={value}
onChange={setValue}
onSubmit={handleSubmit}
/>
)
}
// Child
interface ChildProps {
value: string
onChange: (value: string) => void
onSubmit: () => void
}
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
return (
<div>
<input value={value} onChange={e => onChange(e.target.value)} />
<button onClick={onSubmit}>Submit</button>
</div>
)
}
```
### Render Props for Flexibility
When sub-components need parent context:
```typescript
interface ListProps<T> {
items: T[]
renderItem: (item: T, index: number) => React.ReactNode
renderEmpty?: () => React.ReactNode
}
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
if (items.length === 0 && renderEmpty) {
return <>{renderEmpty()}</>
}
return (
<div>
{items.map((item, index) => renderItem(item, index))}
</div>
)
}
// Usage
<List
items={operations}
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
renderEmpty={() => <EmptyState message="No operations" />}
/>
```

View File

@ -1,317 +0,0 @@
# Hook Extraction Patterns
This document provides detailed guidance on extracting custom hooks from complex components in Dify.
## When to Extract Hooks
Extract a custom hook when you identify:
1. **Coupled state groups** - Multiple `useState` hooks that are always used together
1. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
1. **Business logic** - Data transformations, validations, or calculations
1. **Reusable patterns** - Logic that appears in multiple components
## Extraction Process
### Step 1: Identify State Groups
Look for state variables that are logically related:
```typescript
// ❌ These belong together - extract to hook
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)
// These are model-related state that should be in useModelConfig()
```
### Step 2: Identify Related Effects
Find effects that modify the grouped state:
```typescript
// ❌ These effects belong with the state above
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties.mode
if (mode) {
const newModelConfig = produce(modelConfig, (draft) => {
draft.mode = mode
})
setModelConfig(newModelConfig)
}
}
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
```
### Step 3: Create the Hook
```typescript
// hooks/use-model-config.ts
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { ModelConfig } from '@/models/debug'
import { produce } from 'immer'
import { useEffect, useState } from 'react'
import { ModelModeType } from '@/types/app'
interface UseModelConfigParams {
initialConfig?: Partial<ModelConfig>
currModel?: { model_properties?: { mode?: ModelModeType } }
hasFetchedDetail: boolean
}
interface UseModelConfigReturn {
modelConfig: ModelConfig
setModelConfig: (config: ModelConfig) => void
completionParams: FormValue
setCompletionParams: (params: FormValue) => void
modelModeType: ModelModeType
}
export const useModelConfig = ({
initialConfig,
currModel,
hasFetchedDetail,
}: UseModelConfigParams): UseModelConfigReturn => {
const [modelConfig, setModelConfig] = useState<ModelConfig>({
provider: 'langgenius/openai/openai',
model_id: 'gpt-3.5-turbo',
mode: ModelModeType.unset,
// ... default values
...initialConfig,
})
const [completionParams, setCompletionParams] = useState<FormValue>({})
const modelModeType = modelConfig.mode
// Fill old app data missing model mode
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties?.mode
if (mode) {
setModelConfig(produce(modelConfig, (draft) => {
draft.mode = mode
}))
}
}
}, [hasFetchedDetail, modelModeType, currModel])
return {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
}
}
```
### Step 4: Update Component
```typescript
// Before: 50+ lines of state management
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
// ... lots of related state and effects
}
// After: Clean component
const Configuration: FC = () => {
const {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
} = useModelConfig({
currModel,
hasFetchedDetail,
})
// Component now focuses on UI
}
```
## Naming Conventions
### Hook Names
- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
- Include domain: `useWorkflowVariables`, `useMCPServer`
### File Names
- Kebab-case: `use-model-config.ts`
- Place in `hooks/` subdirectory when multiple hooks exist
- Place alongside component for single-use hooks
### Return Type Names
- Suffix with `Return`: `UseModelConfigReturn`
- Suffix params with `Params`: `UseModelConfigParams`
## Common Hook Patterns in Dify
### 1. Data Fetching Hook (React Query)
```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'
const NAME_SPACE = 'appConfig'
// Query keys for cache management
export const appConfigQueryKeys = {
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}
// Main data hook
export const useAppConfig = (appId: string) => {
return useQuery({
enabled: !!appId,
queryKey: appConfigQueryKeys.detail(appId),
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || null,
})
}
// Invalidation hook for refreshing data
export const useInvalidAppConfig = () => {
return useInvalid([NAME_SPACE])
}
// Usage in component
const Component = () => {
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
const invalidAppConfig = useInvalidAppConfig()
const handleRefresh = () => {
invalidAppConfig() // Invalidates cache and triggers refetch
}
return <div>...</div>
}
```
### 2. Form State Hook
```typescript
// Pattern: Form state + validation + submission
export const useConfigForm = (initialValues: ConfigFormValues) => {
const [values, setValues] = useState(initialValues)
const [errors, setErrors] = useState<Record<string, string>>({})
const [isSubmitting, setIsSubmitting] = useState(false)
const validate = useCallback(() => {
const newErrors: Record<string, string> = {}
if (!values.name) newErrors.name = 'Name is required'
setErrors(newErrors)
return Object.keys(newErrors).length === 0
}, [values])
const handleChange = useCallback((field: string, value: any) => {
setValues(prev => ({ ...prev, [field]: value }))
}, [])
const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
if (!validate()) return
setIsSubmitting(true)
try {
await onSubmit(values)
} finally {
setIsSubmitting(false)
}
}, [values, validate])
return { values, errors, isSubmitting, handleChange, handleSubmit }
}
```
### 3. Modal State Hook
```typescript
// Pattern: Multiple modal management
type ModalType = 'edit' | 'delete' | 'duplicate' | null
export const useModalState = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const [modalData, setModalData] = useState<any>(null)
const openModal = useCallback((type: ModalType, data?: any) => {
setActiveModal(type)
setModalData(data)
}, [])
const closeModal = useCallback(() => {
setActiveModal(null)
setModalData(null)
}, [])
return {
activeModal,
modalData,
openModal,
closeModal,
isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
}
}
```
### 4. Toggle/Boolean Hook
```typescript
// Pattern: Boolean state with convenience methods
export const useToggle = (initialValue = false) => {
const [value, setValue] = useState(initialValue)
const toggle = useCallback(() => setValue(v => !v), [])
const setTrue = useCallback(() => setValue(true), [])
const setFalse = useCallback(() => setValue(false), [])
return [value, { toggle, setTrue, setFalse, set: setValue }] as const
}
// Usage
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
```
## Testing Extracted Hooks
After extraction, test hooks in isolation:
```typescript
// use-model-config.spec.ts
import { renderHook, act } from '@testing-library/react'
import { useModelConfig } from './use-model-config'
describe('useModelConfig', () => {
it('should initialize with default values', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: false,
}))
expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
expect(result.current.modelModeType).toBe(ModelModeType.unset)
})
it('should update model config', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: true,
}))
act(() => {
result.current.setModelConfig({
...result.current.modelConfig,
model_id: 'gpt-4',
})
})
expect(result.current.modelConfig.model_id).toBe('gpt-4')
})
})
```

View File

@ -318,5 +318,5 @@ For more detailed information, refer to:
- `web/vitest.config.ts` - Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/scripts/analyze-component.js` - Component analysis tool
- `web/testing/analyze-component.js` - Component analysis tool
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.

View File

@ -22,12 +22,12 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -57,7 +57,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@v2
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
docker/docker-compose.middleware.yaml

View File

@ -12,7 +12,7 @@ jobs:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Check Docker Compose inputs
id: docker-compose-changes
@ -27,7 +27,7 @@ jobs:
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@v6
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'

View File

@ -90,7 +90,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*

View File

@ -13,13 +13,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: "3.12"
@ -63,13 +63,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: "3.12"

View File

@ -27,7 +27,7 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: changes
with:
@ -38,7 +38,6 @@ jobs:
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
- '.github/workflows/web-tests.yml'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'

View File

@ -19,13 +19,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@v46
with:
files: |
api/**
@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: false
python-version: "3.12"
@ -68,17 +68,15 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@v46
with:
files: |
web/**
.github/workflows/style.yml
files: web/**
- name: Install pnpm
uses: pnpm/action-setup@v4
@ -87,10 +85,10 @@ jobs:
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v6
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
@ -116,14 +114,14 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@v46
with:
files: |
**.sh

View File

@ -16,23 +16,24 @@ jobs:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [16, 18, 20, 22]
defaults:
run:
working-directory: sdks/nodejs-client
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
with:
persist-credentials: false
<<<<<<< HEAD
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v4
=======
- name: Use Node.js
uses: actions/setup-node@v6
>>>>>>> 328897f81c (build: require node 24.13.0 (#30945))
with:
node-version: ${{ matrix.node-version }}
node-version: 24
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'

View File

@ -18,7 +18,7 @@ jobs:
run:
working-directory: web
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@ -51,7 +51,7 @@ jobs:
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: 'lts/*'
cache: pnpm

View File

@ -0,0 +1,421 @@
name: Translate i18n Files with Claude Code
# Note: claude-code-action doesn't support push events directly.
# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
# See: https://github.com/langgenius/dify/issues/30743
on:
repository_dispatch:
types: [i18n-sync]
workflow_dispatch:
inputs:
files:
description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.'
required: false
type: string
languages:
description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.'
required: false
type: string
mode:
description: 'Sync mode: incremental (only changes) or full (re-check all keys)'
required: false
default: 'incremental'
type: choice
options:
- incremental
- full
permissions:
contents: write
pull-requests: write
jobs:
translate:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 60
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Configure Git
run: |
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Detect changed files and generate diff
id: detect_changes
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
# Manual trigger
if [ -n "${{ github.event.inputs.files }}" ]; then
echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT
else
# Get all JSON files in en-US directory
files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ')
echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT
fi
echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT
echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT
# For manual trigger with incremental mode, get diff from last commit
# For full mode, we'll do a complete check anyway
if [ "${{ github.event.inputs.mode }}" == "full" ]; then
echo "Full mode: will check all keys" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
else
git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
if [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
fi
elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
# Triggered by push via trigger-i18n-sync.yml workflow
# Validate required payload fields
if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
exit 1
fi
echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
# Decode the base64-encoded diff from the trigger workflow
if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
echo "Warning: Failed to decode base64 diff payload" >&2
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
elif [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
else
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
else
echo "Unsupported event type: ${{ github.event_name }}"
exit 1
fi
# Truncate diff if too large (keep first 50KB)
if [ -f /tmp/i18n-diff.txt ]; then
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
fi
echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')"
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
prompt: |
You are a professional i18n synchronization engineer for the Dify project.
Your task is to keep all language translations in sync with the English source (en-US).
## CRITICAL TOOL RESTRICTIONS
- Use **Read** tool to read files (NOT cat or bash)
- Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts)
- Use **Bash** ONLY for: git commands, gh commands, pnpm commands
- Run bash commands ONE BY ONE, never combine with && or ||
- NEVER use `$()` command substitution - it's not supported. Split into separate commands instead.
## WORKING DIRECTORY & ABSOLUTE PATHS
Claude Code sandbox working directory may vary. Always use absolute paths:
- For pnpm: `pnpm --dir ${{ github.workspace }}/web <command>`
- For git: `git -C ${{ github.workspace }} <command>`
- For gh: `gh --repo ${{ github.repository }} <command>`
- For file paths: `${{ github.workspace }}/web/i18n/`
## EFFICIENCY RULES
- **ONE Edit per language file** - batch all key additions into a single Edit
- Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them
- Translate ALL keys for a language mentally first, then do ONE Edit
## Context
- Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
- Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }}
- Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
- Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json
- Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts
- Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }}
## CRITICAL DESIGN: Verify First, Then Sync
You MUST follow this three-phase approach:
═══════════════════════════════════════════════════════════════
║ PHASE 1: VERIFY - Analyze and Generate Change Report ║
═══════════════════════════════════════════════════════════════
### Step 1.1: Analyze Git Diff (for incremental mode)
Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff.
Parse the diff to categorize changes:
- Lines with `+` (not `+++`): Added or modified values
- Lines with `-` (not `---`): Removed or old values
- Identify specific keys for each category:
* ADD: Keys that appear only in `+` lines (new keys)
* UPDATE: Keys that appear in both `-` and `+` lines (value changed)
* DELETE: Keys that appear only in `-` lines (removed keys)
### Step 1.2: Read Language Configuration
Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`.
Extract all languages with `supported: true`.
### Step 1.3: Run i18n:check for Each Language
```bash
pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile
```
```bash
pnpm --dir ${{ github.workspace }}/web run i18n:check
```
This will report:
- Missing keys (need to ADD)
- Extra keys (need to DELETE)
### Step 1.4: Generate Change Report
Create a structured report identifying:
```
╔══════════════════════════════════════════════════════════════╗
║ I18N SYNC CHANGE REPORT ║
╠══════════════════════════════════════════════════════════════╣
║ Files to process: [list] ║
║ Languages to sync: [list] ║
╠══════════════════════════════════════════════════════════════╣
║ ADD (New Keys): ║
║ - [filename].[key]: "English value" ║
║ ... ║
╠══════════════════════════════════════════════════════════════╣
║ UPDATE (Modified Keys - MUST re-translate): ║
║ - [filename].[key]: "Old value" → "New value" ║
║ ... ║
╠══════════════════════════════════════════════════════════════╣
║ DELETE (Extra Keys): ║
║ - [language]/[filename].[key] ║
║ ... ║
╚══════════════════════════════════════════════════════════════╝
```
**IMPORTANT**: For UPDATE detection, compare git diff to find keys where
the English value changed. These MUST be re-translated even if target
language already has a translation (it's now stale!).
═══════════════════════════════════════════════════════════════
║ PHASE 2: SYNC - Execute Changes Based on Report ║
═══════════════════════════════════════════════════════════════
### Step 2.1: Process ADD Operations (BATCH per language file)
**CRITICAL WORKFLOW for efficiency:**
1. First, translate ALL new keys for ALL languages mentally
2. Then, for EACH language file, do ONE Edit operation:
- Read the file once
- Insert ALL new keys at the beginning (right after the opening `{`)
- Don't worry about alphabetical order - lint:fix will sort them later
Example Edit (adding 3 keys to zh-Hans/app.json):
```
old_string: '{\n "accessControl"'
new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"'
```
**IMPORTANT**:
- ONE Edit per language file (not one Edit per key!)
- Always use the Edit tool. NEVER use bash scripts, node, or jq.
### Step 2.2: Process UPDATE Operations
**IMPORTANT: Special handling for zh-Hans and ja-JP**
If zh-Hans or ja-JP files were ALSO modified in the same push:
- Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files
- If found, it means someone manually translated them. Apply these rules:
1. **Missing keys**: Still ADD them (completeness required)
2. **Existing translations**: Compare with the NEW English value:
- If translation is **completely wrong** or **unrelated** → Update it
- If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work
- When in doubt, **keep the manual translation**
Example:
- English changed: "Save" → "Save Changes"
- Manual translation: "保存更改" → Keep it (correct meaning)
- Manual translation: "删除" → Update it (completely wrong)
For other languages:
Use Edit tool to replace the old value with the new translation.
You can batch multiple updates in one Edit if they are adjacent.
### Step 2.3: Process DELETE Operations
For extra keys reported by i18n:check:
- Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove`
- Or manually remove from target language JSON files
## Translation Guidelines
- PRESERVE all placeholders exactly as-is:
- `{{variable}}` - Mustache interpolation
- `${variable}` - Template literal
- `<tag>content</tag>` - HTML tags
- `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values)
- Use appropriate language register (formal/informal) based on existing translations
- Match existing translation style in each language
- Technical terms: check existing conventions per language
- For CJK languages: no spaces between characters unless necessary
- For RTL languages (ar-TN, fa-IR): ensure proper text handling
## Output Format Requirements
- Alphabetical key ordering (if original file uses it)
- 2-space indentation
- Trailing newline at end of file
- Valid JSON (use proper escaping for special characters)
═══════════════════════════════════════════════════════════════
║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║
═══════════════════════════════════════════════════════════════
### Step 3.1: Run Lint Fix (IMPORTANT!)
```bash
pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json'
```
This ensures:
- JSON keys are sorted alphabetically (jsonc/sort-keys rule)
- Valid i18n keys (dify-i18n/valid-i18n-keys rule)
- No extra keys (dify-i18n/no-extra-keys rule)
### Step 3.2: Run Final i18n Check
```bash
pnpm --dir ${{ github.workspace }}/web run i18n:check
```
### Step 3.3: Fix Any Remaining Issues
If check reports issues:
- Go back to PHASE 2 for unresolved items
- Repeat until check passes
### Step 3.4: Generate Final Summary
```
╔══════════════════════════════════════════════════════════════╗
║ SYNC COMPLETED SUMMARY ║
╠══════════════════════════════════════════════════════════════╣
║ Language │ Added │ Updated │ Deleted │ Status ║
╠══════════════════════════════════════════════════════════════╣
║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║
║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║
║ ... │ ... │ ... │ ... │ ... ║
╠══════════════════════════════════════════════════════════════╣
║ i18n:check │ PASSED - All keys in sync ║
╚══════════════════════════════════════════════════════════════╝
```
## Mode-Specific Behavior
**SYNC_MODE = "incremental"** (default):
- Focus on keys identified from git diff
- Also check i18n:check output for any missing/extra keys
- Efficient for small changes
**SYNC_MODE = "full"**:
- Compare ALL keys between en-US and each language
- Run i18n:check to identify all discrepancies
- Use for first-time sync or fixing historical issues
## Important Notes
1. Always run i18n:check BEFORE and AFTER making changes
2. The check script is the source of truth for missing/extra keys
3. For UPDATE scenario: git diff is the source of truth for changed values
4. Create a single commit with all translation changes
5. If any translation fails, continue with others and report failures
═══════════════════════════════════════════════════════════════
║ PHASE 4: COMMIT AND CREATE PR ║
═══════════════════════════════════════════════════════════════
After all translations are complete and verified:
### Step 4.1: Check for changes
```bash
git -C ${{ github.workspace }} status --porcelain
```
If there are changes:
### Step 4.2: Create a new branch and commit
Run these git commands ONE BY ONE (not combined with &&).
**IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands:
1. First, get the timestamp:
```bash
date +%Y%m%d-%H%M%S
```
(Note the output, e.g., "20260115-143052")
2. Then create branch using the timestamp value:
```bash
git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052
```
(Replace "20260115-143052" with the actual timestamp from step 1)
3. Stage changes:
```bash
git -C ${{ github.workspace }} add web/i18n/
```
4. Commit:
```bash
git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}"
```
5. Push:
```bash
git -C ${{ github.workspace }} push origin HEAD
```
### Step 4.3: Create Pull Request
```bash
gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary
This PR was automatically generated to sync i18n translation files.
### Changes
- Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
- Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
### Verification
- [x] \`i18n:check\` passed
- [x] \`lint:fix\` applied
🤖 Generated with Claude Code GitHub Action" --base main
```
claude_args: |
--max-turns 150
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"

View File

@ -19,19 +19,19 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@v3
uses: endersonmenezes/free-disk-space@v2
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -18,7 +18,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
@ -29,9 +29,9 @@ jobs:
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
@ -360,7 +360,7 @@ jobs:
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: web-coverage-report
path: web/coverage

34
.mcp.json Normal file
View File

@ -0,0 +1,34 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

View File

@ -50,16 +50,33 @@ WORKDIR /app/api
# Create non-root user
ARG dify_uid=1001
ARG NODE_MAJOR=22
ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1
ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4
RUN groupadd -r -g ${dify_uid} dify && \
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
chown -R dify:dify /app
RUN \
apt-get update \
&& apt-get install -y --no-install-recommends \
ca-certificates \
curl \
gnupg \
&& mkdir -p /etc/apt/keyrings \
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \
&& gpg --show-keys --with-colons /tmp/nodesource.gpg \
| awk -F: '/^fpr:/ {print $10}' \
| grep -Fx "${NODESOURCE_KEY_FPR}" \
&& gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \
&& rm -f /tmp/nodesource.gpg \
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \
> /etc/apt/sources.list.d/nodesource.list \
&& apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs \
nodejs=${NODE_PACKAGE_VERSION} \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security

View File

@ -1,9 +1,8 @@
import base64
from typing import Literal
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@ -16,8 +15,22 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
interval: Literal["month", "year"] = Field(..., description="Billing interval")
plan: str = Field(..., description="Subscription plan")
interval: str = Field(..., description="Billing interval")
@field_validator("plan")
@classmethod
def validate_plan(cls, value: str) -> str:
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
raise ValueError("Invalid plan")
return value
@field_validator("interval")
@classmethod
def validate_interval(cls, value: str) -> str:
if value not in {"month", "year"}:
raise ValueError("Invalid interval")
return value
class PartnerTenantsPayload(BaseModel):

View File

@ -1,5 +1,6 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import marshal_with
@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
conversation_id: UUID
first_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)

View File

@ -1,3 +1,5 @@
from uuid import UUID
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
@ -8,19 +10,19 @@ from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.helper import TimestampField
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
message_id: UUID
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,8 +1,6 @@
from flask_restx import Resource
from pydantic import BaseModel
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@ -12,20 +10,10 @@ from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
class LoadBalancingCredentialPayload(BaseModel):
model: str
model_type: ModelType
credentials: dict[str, object]
register_schema_models(console_ns, LoadBalancingCredentialPayload)
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -36,7 +24,20 @@ class LoadBalancingCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# validate model load balancing credentials
model_load_balancing_service = ModelLoadBalancingService()
@ -48,9 +49,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=payload.model,
model_type=payload.model_type,
credentials=payload.credentials,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
result = False
@ -68,7 +69,6 @@ class LoadBalancingCredentialsValidateApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource):
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -79,7 +79,20 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# validate model load balancing config credentials
model_load_balancing_service = ModelLoadBalancingService()
@ -91,9 +104,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=payload.model,
model_type=payload.model_type,
credentials=payload.credentials,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
config_id=config_id,
)
except CredentialsValidateFailedError as ex:

View File

@ -1,5 +1,4 @@
import io
import logging
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
@ -18,7 +17,6 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
@ -42,8 +40,6 @@ from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
logger = logging.getLogger(__name__)
def is_valid_url(url: str) -> bool:
if not url:
@ -949,8 +945,8 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# 1) Create provider in a short transaction (no network I/O inside)
with session_factory.create_session() as session, session.begin():
# Create provider in transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
@ -966,28 +962,7 @@ class ToolProviderMCPApi(Resource):
authentication=authentication,
)
# 2) Try to fetch tools immediately after creation so they appear without a second save.
# Perform network I/O outside any DB session to avoid holding locks.
try:
reconnect = MCPToolManageService.reconnect_with_url(
server_url=args["server_url"],
headers=args.get("headers") or {},
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
)
# Update just-created provider with authed/tools in a new short transaction
with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
db_provider.authed = reconnect.authed
db_provider.tools = reconnect.tools
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
except Exception:
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
# Final cache invalidation to ensure list views are up to date
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(tenant_id)
return jsonable_encoder(result)

View File

@ -13,6 +13,7 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
@ -459,8 +460,9 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@validate_dataset_token
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _):
def get(self, _, dataset_id):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@ -480,7 +482,8 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def post(self, _):
@validate_dataset_token
def post(self, _, dataset_id):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -503,7 +506,8 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def patch(self, _):
@validate_dataset_token
def patch(self, _, dataset_id):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@ -529,8 +533,9 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@validate_dataset_token
@edit_permission_required
def delete(self, _):
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
@ -550,7 +555,8 @@ class DatasetTagBindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
def post(self, _):
@validate_dataset_token
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -574,7 +580,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
def post(self, _):
@validate_dataset_token
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -597,6 +604,7 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@validate_dataset_token
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")

View File

@ -10,7 +10,12 @@ from controllers.console.auth.error import (
InvalidEmailError,
)
from controllers.console.error import AccountBannedError
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
only_edition_enterprise,
setup_required,
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import email
@ -42,6 +47,7 @@ class LoginApi(Resource):
404: "Account not found",
}
)
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
parser = (
@ -181,6 +187,7 @@ class EmailCodeLoginApi(Resource):
404: "Account not found",
}
)
@decrypt_code_field
def post(self):
parser = (
reqparse.RequestParser()

View File

@ -256,7 +256,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_run_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -340,7 +339,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

View File

@ -26,10 +26,7 @@ from core.variables.variables import VariableUnion
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import (
PersistenceWorkflowInfo,
WorkflowPersistenceLayer,
)
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -112,7 +109,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
system_variables=system_inputs,
)
else:
inputs = self.application_generate_entity.inputs
@ -190,20 +186,20 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._queue_manager.graph_runtime_state = graph_runtime_state
if not self.application_generate_entity.is_single_stepping_container_nodes():
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)

View File

@ -90,7 +90,6 @@ class AppQueueManager:
"""
self._clear_task_belong_cache()
self._q.put(None)
self._graph_runtime_state = None # Release reference to allow GC to reclaim memory
def _clear_task_belong_cache(self) -> None:
"""

View File

@ -92,22 +92,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
db.session.close()
files = self.application_generate_entity.files
system_inputs = SystemVariable(
files=files,
user_id=user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
document_id=self.application_generate_entity.document_id,
original_document_id=self.application_generate_entity.original_document_id,
batch=self.application_generate_entity.batch,
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
)
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
@ -115,12 +99,27 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow=workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
system_variables=system_inputs,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
files=files,
user_id=user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
document_id=self.application_generate_entity.document_id,
original_document_id=self.application_generate_entity.original_document_id,
batch=self.application_generate_entity.batch,
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
)
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
@ -172,21 +171,21 @@ class PipelineRunner(WorkflowBasedAppRunner):
)
self._queue_manager.graph_runtime_state = graph_runtime_state
if not self.application_generate_entity.is_single_stepping_container_nodes():
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()

View File

@ -10,10 +10,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import (
PersistenceWorkflowInfo,
WorkflowPersistenceLayer,
)
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -83,7 +80,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
system_variables=system_inputs,
)
else:
inputs = self.application_generate_entity.inputs
@ -136,21 +132,20 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
command_channel=command_channel,
)
if not self.application_generate_entity.is_single_stepping_container_nodes():
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)

View File

@ -130,8 +130,6 @@ class WorkflowBasedAppRunner:
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
*,
system_variables: SystemVariable | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
@ -149,10 +147,9 @@ class WorkflowBasedAppRunner:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
system_variables = system_variables or SystemVariable.empty()
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=system_variables,
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
@ -223,7 +220,7 @@ class WorkflowBasedAppRunner:
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
base_node_configs = [
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
@ -231,74 +228,26 @@ class WorkflowBasedAppRunner:
or (start_node_id and node.get("id") == start_node_id)
]
# Build a base graph config (without synthetic entry) to keep node-level context minimal.
base_graph_config = graph_config.copy()
base_graph_config["nodes"] = base_node_configs
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in base_node_configs if isinstance(node.get("id"), str)]
node_ids = [node.get("id") for node in node_configs]
# filter edges only in the specified node type
base_edge_configs = [
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
base_graph_config["edges"] = base_edge_configs
# Inject a synthetic start node so Graph validation accepts the single-node graph
# (loop/iteration nodes are containers and cannot serve as graph roots).
synthetic_start_node_id = f"{node_id}_single_step_start"
synthetic_start_node = {
"id": synthetic_start_node_id,
"type": "custom",
"data": {
"type": NodeType.START,
"title": "Start",
"desc": "Synthetic start for single-step run",
"version": "1",
"variables": [],
},
}
synthetic_end_node_id = f"{node_id}_single_step_end"
synthetic_end_node = {
"id": synthetic_end_node_id,
"type": "custom",
"data": {
"type": NodeType.END,
"title": "End",
"desc": "Synthetic end for single-step run",
"version": "1",
"outputs": [],
},
}
graph_config_with_entry = base_graph_config.copy()
graph_config_with_entry["nodes"] = [*base_node_configs, synthetic_start_node, synthetic_end_node]
graph_config_with_entry["edges"] = [
*base_edge_configs,
{
"source": synthetic_start_node_id,
"target": node_id,
"sourceHandle": "source",
"targetHandle": "target",
},
{
"source": node_id,
"target": synthetic_end_node_id,
"sourceHandle": "source",
"targetHandle": "target",
},
]
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=base_graph_config,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
@ -311,16 +260,14 @@ class WorkflowBasedAppRunner:
)
# init graph
graph = Graph.init(
graph_config=graph_config_with_entry, node_factory=node_factory, root_node_id=synthetic_start_node_id
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
for node in base_node_configs:
for node in node_configs:
if node.get("id") == node_id:
target_node_config = node
break

View File

@ -228,9 +228,6 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
single_loop_run: SingleLoopRunEntity | None = None
def is_single_stepping_container_nodes(self) -> bool:
return self.single_iteration_run is not None or self.single_loop_run is not None
class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
@ -261,9 +258,6 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
single_loop_run: SingleLoopRunEntity | None = None
def is_single_stepping_container_nodes(self) -> bool:
return self.single_iteration_run is not None or self.single_loop_run is not None
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""

View File

@ -1,14 +1,9 @@
from collections.abc import Mapping
from textwrap import dedent
from typing import Any
from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer):
# Use separate placeholder for base64-encoded template to avoid confusion
_template_b64_placeholder: str = "{{template_b64}}"
@classmethod
def transform_response(cls, response: str):
"""
@ -18,35 +13,18 @@ class Jinja2TemplateTransformer(TemplateTransformer):
"""
return {"result": cls.extract_result_str_from_response(response)}
@classmethod
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
"""
Override base class to use base64 encoding for template code.
This prevents issues with special characters (quotes, newlines) in templates
breaking the generated Python script. Fixes #26818.
"""
script = cls.get_runner_script()
# Encode template as base64 to safely embed any content including quotes
code_b64 = cls.serialize_code(code)
script = script.replace(cls._template_b64_placeholder, code_b64)
inputs_str = cls.serialize_inputs(inputs)
script = script.replace(cls._inputs_placeholder, inputs_str)
return script
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
import jinja2
import json
from base64 import b64decode
# declare main function
def main(**inputs):
# Decode base64-encoded template to handle special characters safely
template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
template = jinja2.Template(template_code)
import jinja2
template = jinja2.Template('''{cls._code_placeholder}''')
return template.render(**inputs)
import json
from base64 import b64decode
# decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))

View File

@ -13,15 +13,6 @@ class TemplateTransformer(ABC):
_inputs_placeholder: str = "{{inputs}}"
_result_tag: str = "<<RESULT>>"
@classmethod
def serialize_code(cls, code: str) -> str:
"""
Serialize template code to base64 to safely embed in generated script.
This prevents issues with special characters like quotes breaking the script.
"""
code_bytes = code.encode("utf-8")
return b64encode(code_bytes).decode("utf-8")
@classmethod
def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
"""

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, cast
from typing import Any
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
@ -50,9 +50,7 @@ class ToolProviderListCache:
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
keys = ["builtin", "model", "api", "workflow", "mcp"]
pipeline = redis_client.pipeline()
for key in keys:
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
pipeline.delete(cache_key)
pipeline.execute()
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
keys = list(redis_client.scan_iter(pattern))
if keys:
redis_client.delete(*keys)

View File

@ -313,20 +313,17 @@ class StreamableHTTPTransport:
if is_initialization:
self._maybe_extract_session_id_from_response(response)
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
def _handle_json_response(
self,

View File

@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
auto_generate: PluginParameterAutoGenerate | None = None
template: PluginParameterTemplate | None = None
required: bool = False
default: Union[float, int, str, bool, list, dict] | None = None
default: Union[float, int, str, bool] | None = None
min: Union[float, int] | None = None
max: Union[float, int] | None = None
precision: int | None = None

View File

@ -13,7 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -416,12 +416,12 @@ class RetrievalService:
child_index_node_ids = [i for i in child_index_node_ids if i]
index_node_ids = [i for i in index_node_ids if i]
segment_ids: list[str] = []
segment_ids = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map: dict[str, list[dict[str, Any]]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
attachment_map = {}
child_chunk_map: dict[Any, Any] = {}
doc_segment_map = {}
with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
@ -432,7 +432,7 @@ class RetrievalService:
attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
else:
attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
if attachment["segment_id"] in doc_segment_map:
if attachment["attachment_id"] in doc_segment_map:
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
else:
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
@ -502,7 +502,7 @@ class RetrievalService:
"child_chunks": child_chunk_details,
}
segment_child_map[segment.id] = map_detail
record: dict[str, Any] = {
record = {
"segment": segment,
}
records.append(record)
@ -510,13 +510,13 @@ class RetrievalService:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
max_score = 0.0
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
document = doc_to_document_map.get(segment.index_node_id)
if document:
max_score = max(max_score, document.metadata.get("score", 0.0))
for attachment_info in attachment_infos:
file_doc = doc_to_document_map.get(attachment_info["id"])
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
file_document = doc_to_document_map.get(attachment_info["id"])
if file_document:
max_score = max(max_score, file_document.metadata.get("score", 0.0))
record = {
"segment": segment,
"score": max_score,
@ -531,26 +531,18 @@ class RetrievalService:
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
result: list[RetrievalSegments] = []
result = []
for record in records:
# Extract segment
segment = record["segment"]
# Extract child_chunks, ensuring it's a list or None
raw_child_chunks = record.get("child_chunks")
child_chunks_list: list[RetrievalChildChunk] | None = None
if isinstance(raw_child_chunks, list):
# Sort by score descending
sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
child_chunks_list = [
RetrievalChildChunk(
id=chunk["id"],
content=chunk["content"],
score=chunk.get("score", 0.0),
position=chunk["position"],
)
for chunk in sorted_chunks
]
child_chunks = record.get("child_chunks")
if not isinstance(child_chunks, list):
child_chunks = None
if child_chunks:
child_chunks = sorted(child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
# Extract files, ensuring it's a list or None
files = record.get("files")
@ -567,11 +559,11 @@ class RetrievalService:
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks_list, score=score, files=files
segment=segment, child_chunks=child_chunks, score=score, files=files
)
result.append(retrieval_segment)
return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
return sorted(result, key=lambda x: x.score, reverse=True)
except Exception as e:
db.session.rollback()
raise e

View File

@ -255,10 +255,7 @@ class PGVector(BaseVector):
return
with self._get_cursor() as cur:
cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
if not cur.fetchone():
cur.execute("CREATE EXTENSION vector")
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# PG hnsw index only support 2000 dimension or less
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing

View File

@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, literal, or_, select
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@ -1036,7 +1036,7 @@ class DatasetRetrieval:
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
self.process_metadata_filter_func(
self._process_metadata_filter_func(
sequence,
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
@ -1072,7 +1072,7 @@ class DatasetRetrieval:
value=expected_value,
)
)
filters = self.process_metadata_filter_func(
filters = self._process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -1168,9 +1168,8 @@ class DatasetRetrieval:
return None
return automatic_metadata_filters
@classmethod
def process_metadata_filter_func(
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return filters
@ -1219,20 +1218,6 @@ class DatasetRetrieval:
case "" | ">=":
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case "in" | "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
# `field in []` is False, `field not in []` is True
filters.append(literal(condition == "not in"))
else:
op = json_field.in_ if condition == "in" else json_field.notin_
filters.append(op(value_list))
case _:
pass

View File

@ -6,15 +6,7 @@ from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import (
AudioContent,
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
TextResourceContents,
)
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@ -61,19 +53,10 @@ class MCPTool(Tool):
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent | AudioContent):
yield self.create_blob_message(
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
)
elif isinstance(content, EmbeddedResource):
resource = content.resource
if isinstance(resource, TextResourceContents):
yield self.create_text_message(resource.text)
elif isinstance(resource, BlobResourceContents):
mime_type = resource.mimeType or "application/octet-stream"
yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
else:
raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
elif isinstance(content, AudioContent):
yield self._process_audio_content(content)
else:
logger.warning("Unsupported content type=%s", type(content))
@ -118,6 +101,14 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
"""Process image content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
"""Process audio content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool(
entity=self.entity,

View File

@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -48,30 +47,33 @@ class WorkflowToolProviderController(ToolProviderController):
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
with session_factory.create_session() as session, session.begin():
app = session.get(App, db_provider.app_id)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
if not provider:
raise ValueError("workflow provider not found")
app = session.get(App, provider.app_id)
if not app:
raise ValueError("app not found")
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
user = session.get(Account, provider.user_id) if provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=user.name if user else "",
name=db_provider.label,
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=db_provider.icon,
name=provider.label,
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
icon=provider.icon,
),
credentials_schema=[],
plugin_id=None,
),
provider_id="",
provider_id=provider.id or "",
)
controller.tools = [
controller._get_db_provider_tool(db_provider, app, session=session, user=user),
controller._get_db_provider_tool(provider, app, session=session, user=user),
]
return controller

View File

@ -149,49 +149,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return isinstance(variable, NoneSegment) or len(variable.value) == 0
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
started_at = naive_utc_now()
inputs: dict[str, object] = {"iterator_selector": []}
usage = LLMUsage.empty_usage()
yield IterationStartedEvent(
start_at=started_at,
inputs=inputs,
metadata={"iteration_length": 0},
)
# Try our best to preserve the type information.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": []},
steps=0,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: {},
},
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
inputs=inputs,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: {},
},
llm_usage=usage,
)
)

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import and_, func, or_, select
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
self._process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
filters = self._process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -603,6 +603,87 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return [], usage
return automatic_metadata_filters, usage
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return filters
json_field = Document.doc_metadata[metadata_name].as_string()
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(json_field.like(f"%{value}"))
case "in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(False))
else:
filters.append(json_field.in_(value_list))
case "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(True))
else:
filters.append(json_field.notin_(value_list))
case "is" | "=":
if isinstance(value, str):
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re
from .delete_tool_parameters_cache_when_sync_draft_workflow import (
handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow,
)
from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
@ -30,6 +31,7 @@ __all__ = [
"handle_create_installed_app_when_app_created",
"handle_create_site_record_when_app_created",
"handle_delete_tool_parameters_cache_when_sync_draft_workflow",
"handle_queue_credential_sync_when_tenant_created",
"handle_sync_plugin_trigger_when_app_created",
"handle_sync_webhook_when_app_created",
"handle_sync_workflow_schedule_when_app_published",

View File

@ -0,0 +1,19 @@
from configs import dify_config
from events.tenant_event import tenant_was_created
from services.enterprise.workspace_sync import WorkspaceSyncService
@tenant_was_created.connect
def handle(sender, **kwargs):
"""Queue credential sync when a tenant/workspace is created."""
# Only queue sync tasks if plugin manager (enterprise feature) is enabled
if not dify_config.ENTERPRISE_ENABLED:
return
tenant = sender
# Determine source from kwargs if available, otherwise use generic
source = kwargs.get("source", "tenant_created")
# Queue credential sync task to Redis for enterprise backend to process
WorkspaceSyncService.queue_credential_sync(tenant.id, source=source)

View File

@ -14,8 +14,7 @@ from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.workflow_service import WorkflowService

View File

@ -21,7 +21,7 @@ from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -141,7 +141,7 @@ class AsyncWorkflowService:
trigger_log_repo.update(trigger_log)
session.commit()
raise WorkflowQuotaLimitError(
raise InvokeRateLimitError(
f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
) from e

View File

@ -0,0 +1,58 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from redis import RedisError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"
class WorkspaceSyncService:
"""Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
@staticmethod
def queue_credential_sync(workspace_id: str, *, source: str) -> bool:
"""
Queue a credential sync task for a newly created workspace.
This publishes a task to Redis that will be consumed by the enterprise backend
worker to sync credentials with the plugin-manager.
Args:
workspace_id: The workspace/tenant ID to sync credentials for
source: Source of the sync request (for debugging/tracking)
Returns:
bool: True if task was queued successfully, False otherwise
"""
try:
task = {
"task_id": str(uuid.uuid4()),
"workspace_id": workspace_id,
"retry_count": 0,
"created_at": datetime.now(UTC).isoformat(),
"source": source,
}
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task))
logger.info(
"Queued credential sync task for workspace %s, task_id: %s, source: %s",
workspace_id,
task["task_id"],
source,
)
return True
except (RedisError, TypeError) as e:
logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True)
# Don't raise - we don't want to fail workspace creation if queueing fails
# The scheduled task will catch it later
return False

View File

@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
pass
class WorkflowQuotaLimitError(Exception):
"""Raised when workflow execution quota is exceeded (for async/background workflows)."""
class InvokeRateLimitError(Exception):
"""Raised when rate limit is exceeded for workflow invocations."""
pass

View File

@ -146,7 +146,7 @@ class PluginParameterService:
provider,
action,
resolved_credentials,
original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
CredentialType.API_KEY.value,
parameter,
)
.options

View File

@ -319,14 +319,8 @@ class MCPToolManageService:
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database with retrieved tools (ensure description is a non-null string)
tools_payload = []
for tool in tools:
data = tool.model_dump()
if data.get("description") is None:
data["description"] = ""
tools_payload.append(data)
db_provider.tools = json.dumps(tools_payload)
# Update database with retrieved tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.flush()
@ -626,21 +620,6 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash,
)
@staticmethod
def reconnect_with_url(
*,
server_url: str,
headers: dict[str, str],
timeout: float | None,
sse_read_timeout: float | None,
) -> ReconnectResult:
return MCPToolManageService._reconnect_with_url(
server_url=server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
)
@staticmethod
def _reconnect_with_url(
*,
@ -663,16 +642,9 @@ class MCPToolManageService:
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
# Ensure tool descriptions are non-null in payload
tools_payload = []
for t in tools:
d = t.model_dump()
if d.get("description") is None:
d["description"] = ""
tools_payload.append(d)
return ReconnectResult(
authed=True,
tools=json.dumps(tools_payload),
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:

View File

@ -5,8 +5,8 @@ from datetime import datetime
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
@ -68,27 +68,26 @@ class WorkflowToolManageService:
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}")
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
app_id=workflow_app_id,
name=name,
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
privacy_policy=privacy_policy,
version=workflow.version,
)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
app_id=workflow_app_id,
name=name,
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
privacy_policy=privacy_policy,
version=workflow.version,
)
session.add(workflow_tool_provider)
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
with session_factory.create_session() as session, session.begin():
session.add(workflow_tool_provider)
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels

View File

@ -868,111 +868,48 @@ class TriggerProviderService:
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
with Session(db.engine, expire_on_commit=False) as session:
try:
# Get subscription within the transaction
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
# Decrypt existing credentials for merging
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
# TODO: Trying to invoke update api of the plugin trigger provider
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
merged_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
# FALLBACK: If the update api is not implemented, delete the previous subscription and create a new one
user_id = subscription.user_id
# Delete the previous subscription
user_id = subscription.user_id
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=subscription.credentials,
credential_type=credential_type,
)
# TODO: Trying to invoke update api of the plugin trigger provider
# FALLBACK: If the update api is not implemented,
# delete the previous subscription and create a new one
# Unsubscribe the previous subscription (external call, but we'll handle errors)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=decrypted_credentials,
credential_type=credential_type,
)
except Exception as e:
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
# Continue anyway - the subscription might already be deleted externally
# Create a new subscription with the same subscription_id and endpoint_id (external call)
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=merged_credentials,
credential_type=credential_type,
)
# Update the subscription in the same transaction
# Inline update logic to reuse the same session
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing and existing.id != subscription.id:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name
# Update parameters
subscription.parameters = dict(parameters)
# Update credentials with merged (and encrypted) values
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
# Update properties
if new_subscription.properties:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
# Update expiration timestamp
if new_subscription.expires_at is not None:
subscription.expires_at = new_subscription.expires_at
# Commit the transaction
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
except Exception as e:
# Rollback on any error
session.rollback()
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
raise
# Create a new subscription with the same subscription_id and endpoint_id
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=credentials,
credential_type=credential_type,
)
TriggerProviderService.update_trigger_subscription(
tenant_id=tenant_id,
subscription_id=subscription.id,
name=name,
parameters=parameters,
credentials=credentials,
properties=new_subscription.properties,
expires_at=new_subscription.expires_at,
)

View File

@ -863,18 +863,10 @@ class WebhookService:
not_found_in_cache.append(node_id)
continue
lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock"
lock = redis_client.lock(lock_key, timeout=10)
lock_acquired = False
try:
# acquire the lock with blocking and timeout
lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
if not lock_acquired:
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
with Session(db.engine) as session:
with Session(db.engine) as session:
try:
# lock the concurrent webhook trigger creation
redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowWebhookTrigger).where(
@ -911,16 +903,11 @@ class WebhookService:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
finally:
# release the lock only if it was acquired
if lock_acquired:
try:
lock.release()
except Exception:
logger.exception("Failed to release lock for webhook sync, app %s", app.id)
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
finally:
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
@classmethod
def generate_webhook_id(cls) -> str:

View File

@ -7,14 +7,11 @@ CODE_LANGUAGE = CodeLanguage.JINJA2
def test_jinja2():
"""Test basic Jinja2 template rendering."""
template = "Hello {{template}}"
# Template must be base64 encoded to match the new safe embedding approach
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
code = (
Jinja2TemplateTransformer.get_runner_script()
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
.replace(Jinja2TemplateTransformer._code_placeholder, template)
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
)
result = CodeExecutor.execute_code(
@ -24,7 +21,6 @@ def test_jinja2():
def test_jinja2_with_code_template():
"""Test template rendering via the high-level workflow API."""
result = CodeExecutor.execute_workflow_code_template(
language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
)
@ -32,64 +28,7 @@ def test_jinja2_with_code_template():
def test_jinja2_get_runner_script():
"""Test that runner script contains required placeholders."""
runner_script = Jinja2TemplateTransformer.get_runner_script()
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
def test_jinja2_template_with_special_characters():
"""
Test that templates with special characters (quotes, newlines) render correctly.
This is a regression test for issue #26818 where textarea pre-fill values
containing special characters would break template rendering.
"""
# Template with triple quotes, single quotes, double quotes, and newlines
template = """<html>
<body>
<input value="{{ task.get('Task ID', '') }}"/>
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
<p>Status: "{{ status }}"</p>
<pre>'''code block'''</pre>
</body>
</html>"""
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
# Verify the template rendered correctly with all special characters
output = result["result"]
assert 'value="TASK-123"' in output
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
assert 'Status: "completed"' in output
assert "'''code block'''" in output
def test_jinja2_template_with_html_textarea_prefill():
"""
Specific test for HTML textarea with Jinja2 variable pre-fill.
Verifies fix for issue #26818.
"""
template = "<textarea name='notes'>{{ notes }}</textarea>"
notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes."
inputs = {"notes": notes_content}
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
expected_output = f"<textarea name='notes'>{notes_content}</textarea>"
assert result["result"] == expected_output
def test_jinja2_assemble_runner_script_encodes_template():
"""Test that assemble_runner_script properly base64 encodes the template."""
template = "Hello {{ name }}!"
inputs = {"name": "World"}
script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs)
# The template should be base64 encoded in the script
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
assert template_b64 in script
# The raw template should NOT appear in the script (it's encoded)
assert "Hello {{ name }}!" not in script

View File

@ -1,682 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from extensions.ext_database import db
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
class TestTriggerProviderService:
"""Integration tests for TriggerProviderService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager,
patch("services.trigger.trigger_provider_service.redis_client") as mock_redis_client,
patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") as mock_delete_cache,
patch("services.account_service.FeatureService") as mock_account_feature_service,
):
# Setup default mock returns
mock_provider_controller = MagicMock()
mock_provider_controller.get_credential_schema_config.return_value = MagicMock()
mock_provider_controller.get_properties_schema.return_value = MagicMock()
mock_trigger_manager.get_trigger_provider.return_value = mock_provider_controller
# Mock redis lock
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock(return_value=None)
mock_lock.__exit__ = MagicMock(return_value=None)
mock_redis_client.lock.return_value = mock_lock
# Setup account feature service mock
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
yield {
"trigger_manager": mock_trigger_manager,
"redis_client": mock_redis_client,
"delete_cache": mock_delete_cache,
"provider_controller": mock_provider_controller,
"account_feature_service": mock_account_feature_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
from services.account_service import AccountService, TenantService
# Setup mocks for account creation
mock_external_service_dependencies[
"account_feature_service"
].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies[
"trigger_manager"
].get_trigger_provider.return_value = mock_external_service_dependencies["provider_controller"]
# Create account and tenant
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
return account, tenant
def _create_test_subscription(
self,
db_session_with_containers,
tenant_id,
user_id,
provider_id,
credential_type,
credentials,
mock_external_service_dependencies,
):
"""
Helper method to create a test trigger subscription.
Args:
db_session_with_containers: Database session
tenant_id: Tenant ID
user_id: User ID
provider_id: Provider ID
credential_type: Credential type
credentials: Credentials dict
mock_external_service_dependencies: Mock dependencies
Returns:
TriggerSubscription: Created subscription instance
"""
fake = Faker()
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import create_provider_encrypter
# Use mock provider controller to encrypt credentials
provider_controller = mock_external_service_dependencies["provider_controller"]
# Create encrypter for credentials
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
subscription = TriggerSubscription(
name=fake.word(),
tenant_id=tenant_id,
user_id=user_id,
provider_id=str(provider_id),
endpoint_id=fake.uuid4(),
parameters={"param1": "value1"},
properties={"prop1": "value1"},
credentials=dict(credential_encrypter.encrypt(credentials)),
credential_type=credential_type.value,
credential_expires_at=-1,
expires_at=-1,
)
db.session.add(subscription)
db.session.commit()
db.session.refresh(subscription)
return subscription
def test_rebuild_trigger_subscription_success_with_merged_credentials(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful rebuild with credential merging (HIDDEN_VALUE handling).
This test verifies:
- Credentials are properly merged (HIDDEN_VALUE replaced with existing values)
- Single transaction wraps all operations
- Merged credentials are used for subscribe and update
- Database state is correctly updated
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create initial subscription with credentials
original_credentials = {"api_key": "original-secret-key", "api_secret": "original-secret"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Prepare new credentials with HIDDEN_VALUE for api_key (should keep original)
# and new value for api_secret (should update)
new_credentials = {
"api_key": HIDDEN_VALUE, # Should be replaced with original
"api_secret": "new-secret-value", # Should be updated
}
# Mock subscribe_trigger to return a new subscription entity
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={"param1": "value1"},
properties={"prop1": "new_prop_value"},
expires_at=1234567890,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
# Mock unsubscribe_trigger
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={"param1": "updated_value"},
name="updated_name",
)
# Verify unsubscribe was called with decrypted original credentials
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.assert_called_once()
unsubscribe_call_args = mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.call_args
assert unsubscribe_call_args.kwargs["tenant_id"] == tenant.id
assert unsubscribe_call_args.kwargs["provider_id"] == provider_id
assert unsubscribe_call_args.kwargs["credential_type"] == credential_type
# Verify subscribe was called with merged credentials (api_key from original, api_secret new)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"] # Merged from original
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
# Verify database state was updated
db.session.refresh(subscription)
assert subscription.name == "updated_name"
assert subscription.parameters == {"param1": "updated_value"}
# Verify credentials in DB were updated with merged values (decrypt to check)
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import create_provider_encrypter
# Use mock provider controller to decrypt credentials
provider_controller = mock_external_service_dependencies["provider_controller"]
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant.id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
decrypted_db_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
assert decrypted_db_credentials["api_key"] == original_credentials["api_key"]
assert decrypted_db_credentials["api_secret"] == "new-secret-value"
# Verify cache was cleared
mock_external_service_dependencies["delete_cache"].assert_called_once_with(
tenant_id=tenant.id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
def test_rebuild_trigger_subscription_with_all_new_credentials(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when all credentials are new (no HIDDEN_VALUE).
This test verifies:
- All new credentials are used when no HIDDEN_VALUE is present
- Merged credentials contain only new values
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create initial subscription
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# All new credentials (no HIDDEN_VALUE)
new_credentials = {
"api_key": "completely-new-key",
"api_secret": "completely-new-secret",
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with all new credentials
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == "completely-new-key"
assert subscribe_credentials["api_secret"] == "completely-new-secret"
def test_rebuild_trigger_subscription_with_all_hidden_values(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
This test verifies:
- All HIDDEN_VALUE credentials are replaced with existing values
- Original credentials are preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# All HIDDEN_VALUE (should preserve all original)
new_credentials = {
"api_key": HIDDEN_VALUE,
"api_secret": HIDDEN_VALUE,
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with all original credentials
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
This test verifies:
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Original has only api_key
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# HIDDEN_VALUE for non-existent key should use UNKNOWN_VALUE
new_credentials = {
"api_key": HIDDEN_VALUE,
"non_existent_key": HIDDEN_VALUE, # This key doesn't exist in original
}
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials=new_credentials,
parameters={},
)
# Verify subscribe was called with original api_key and UNKNOWN_VALUE for missing key
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
def test_rebuild_trigger_subscription_rollback_on_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that transaction is rolled back on error.
This test verifies:
- Database transaction is rolled back when an error occurs
- Original subscription state is preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
original_name = subscription.name
original_parameters = subscription.parameters.copy()
# Make subscribe_trigger raise an error
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.side_effect = ValueError(
"Subscribe failed"
)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Execute rebuild and expect error
with pytest.raises(ValueError, match="Subscribe failed"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscription state was not changed (rolled back)
db.session.refresh(subscription)
assert subscription.name == original_name
assert subscription.parameters == original_parameters
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that unsubscribe errors are handled gracefully and operation continues.
This test verifies:
- Unsubscribe errors are caught and logged but don't stop the rebuild
- Rebuild continues even if unsubscribe fails
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Make unsubscribe_trigger raise an error (should be caught and continue)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
"Unsubscribe failed"
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
# Execute rebuild - should succeed despite unsubscribe error
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscribe was still called (operation continued)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
# Verify subscription was updated
db.session.refresh(subscription)
assert subscription.parameters == {}
def test_rebuild_trigger_subscription_subscription_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when subscription is not found.
This test verifies:
- Proper error is raised when subscription doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
fake_subscription_id = fake.uuid4()
with pytest.raises(ValueError, match="not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake_subscription_id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when provider is not found.
This test verifies:
- Proper error is raised when provider doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
# Make get_trigger_provider return None
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
with pytest.raises(ValueError, match="Provider.*not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake.uuid4(),
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_unsupported_credential_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when credential type is not supported for rebuild.
This test verifies:
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.UNAUTHORIZED # Not supported
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{},
mock_external_service_dependencies,
)
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_name_uniqueness_check(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that name uniqueness is checked when updating name.
This test verifies:
- Error is raised when new name conflicts with existing subscription
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
# Create first subscription
subscription1 = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{"api_key": "key1"},
mock_external_service_dependencies,
)
# Create second subscription with different name
subscription2 = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{"api_key": "key2"},
mock_external_service_dependencies,
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription2.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
# Try to rename subscription2 to subscription1's name (should fail)
with pytest.raises(ValueError, match="already exists"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription2.id,
credentials={"api_key": "new-key"},
parameters={},
name=subscription1.name, # Conflicting name
)

View File

@ -705,207 +705,3 @@ class TestWorkflowToolManageService:
db.session.refresh(created_tool)
assert created_tool.name == first_tool_name
assert created_tool.updated_at is not None
def test_create_workflow_tool_with_file_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILE parameter having a file object as default.
This test verifies:
- FILE parameters can have file object defaults
- The default value (dict with id/base64Url) is properly handled
- Tool creation succeeds without Pydantic validation errors
Related issue: Array[File] default value causes Pydantic validation errors.
"""
fake = Faker()
# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
# Create workflow graph with a FILE variable that has a default value
workflow_graph = {
"nodes": [
{
"id": "start_node",
"data": {
"type": "start",
"variables": [
{
"variable": "document",
"label": "Document",
"type": "file",
"required": False,
"default": {"id": fake.uuid4(), "base64Url": ""},
}
],
},
}
]
}
workflow.graph = json.dumps(workflow_graph)
# Setup workflow tool parameters with FILE type
file_parameters = [
{
"name": "document",
"description": "Upload a document",
"form": "form",
"type": "file",
"required": False,
}
]
# Execute the method under test
# Note: from_db is mocked, so this test primarily validates the parameter configuration
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "📄"},
description=fake.text(max_nb_chars=200),
parameters=file_parameters,
)
# Verify the result
assert result == {"result": "success"}
def test_create_workflow_tool_with_files_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
This test verifies:
- FILES parameters can have a list of file objects as default
- The default value (list of dicts with id/base64Url) is properly handled
- Tool creation succeeds without Pydantic validation errors
Related issue: Array[File] default value causes 4 Pydantic validation errors
because PluginParameter.default only accepts Union[float, int, str, bool] | None.
"""
fake = Faker()
# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
# Create workflow graph with a FILE_LIST variable that has a default value
workflow_graph = {
"nodes": [
{
"id": "start_node",
"data": {
"type": "start",
"variables": [
{
"variable": "documents",
"label": "Documents",
"type": "file-list",
"required": False,
"default": [
{"id": fake.uuid4(), "base64Url": ""},
{"id": fake.uuid4(), "base64Url": ""},
],
}
],
},
}
]
}
workflow.graph = json.dumps(workflow_graph)
# Setup workflow tool parameters with FILES type
files_parameters = [
{
"name": "documents",
"description": "Upload multiple documents",
"form": "form",
"type": "files",
"required": False,
}
]
# Execute the method under test
# Note: from_db is mocked, so this test primarily validates the parameter configuration
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "📁"},
description=fake.text(max_nb_chars=200),
parameters=files_parameters,
)
# Verify the result
assert result == {"result": "success"}
def test_create_workflow_tool_db_commit_before_validation(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that database commit happens before validation, causing DB pollution on validation failure.
This test verifies the second bug:
- WorkflowToolProvider is committed to database BEFORE from_db validation
- If validation fails, the record remains in the database
- Subsequent attempts fail with "Tool already exists" error
This demonstrates why we need to validate BEFORE database commit.
"""
fake = Faker()
# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
tool_name = fake.word()
# Mock from_db to raise validation error
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.side_effect = ValueError(
"Validation failed: default parameter type mismatch"
)
# Attempt to create workflow tool (will fail at validation stage)
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=self._create_test_workflow_tool_parameters(),
)
assert "Validation failed" in str(exc_info.value)
# Verify the tool was NOT created in database
# This is the expected behavior (no pollution)
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.name == tool_name,
)
.count()
)
# The record should NOT exist because the transaction should be rolled back
# Currently, due to the bug, the record might exist (this test documents the bug)
# After the fix, this should always be 0
# For now, we document that the record may exist, demonstrating the bug
# assert tool_count == 0 # Expected after fix

View File

@ -12,12 +12,10 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
_, Jinja2TemplateTransformer = self.jinja2_imports
template = "Hello {{template}}"
# Template must be base64 encoded to match the new safe embedding approach
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
code = (
Jinja2TemplateTransformer.get_runner_script()
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
.replace(Jinja2TemplateTransformer._code_placeholder, template)
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
)
result = CodeExecutor.execute_code(
@ -39,34 +37,6 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
_, Jinja2TemplateTransformer = self.jinja2_imports
runner_script = Jinja2TemplateTransformer.get_runner_script()
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
"""
Test that templates with special characters (quotes, newlines) render correctly.
This is a regression test for issue #26818 where textarea pre-fill values
containing special characters would break template rendering.
"""
CodeExecutor, CodeLanguage = self.code_executor_imports
# Template with triple quotes, single quotes, double quotes, and newlines
template = """<html>
<body>
<input value="{{ task.get('Task ID', '') }}"/>
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
<p>Status: "{{ status }}"</p>
<pre>'''code block'''</pre>
</body>
</html>"""
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
# Verify the template rendered correctly with all special characters
output = result["result"]
assert 'value="TASK-123"' in output
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
assert 'Status: "completed"' in output
assert "'''code block'''" in output

View File

@ -1,145 +0,0 @@
"""Unit tests for load balancing credential validation APIs."""
from __future__ import annotations
import builtins
import importlib
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from flask.views import MethodView
from werkzeug.exceptions import Forbidden
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
from models.account import TenantAccountRole
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
"""Reload controller module with lightweight decorators for testing."""
from controllers.console import console_ns, wraps
from libs import login
def _noop(func):
return func
monkeypatch.setattr(login, "login_required", _noop)
monkeypatch.setattr(wraps, "setup_required", _noop)
monkeypatch.setattr(wraps, "account_initialization_required", _noop)
def _noop_route(*args, **kwargs): # type: ignore[override]
def _decorator(cls):
return cls
return _decorator
monkeypatch.setattr(console_ns, "route", _noop_route)
module_name = "controllers.console.workspace.load_balancing_config"
sys.modules.pop(module_name, None)
module = importlib.import_module(module_name)
return module
def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
return SimpleNamespace(current_role=role)
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
user = _mock_user(role)
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
mock_service = MagicMock()
monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
return mock_service
def _request_payload():
return {"model": "gpt-4o", "model_type": ModelType.LLM, "credentials": {"api_key": "sk-***"}}
def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
service = _prepare_context(load_balancing_module, monkeypatch)
with app.test_request_context(
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
method="POST",
json=_request_payload(),
):
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
assert response == {"result": "success"}
service.validate_load_balancing_credentials.assert_called_once_with(
tenant_id="tenant-123",
provider="openai",
model="gpt-4o",
model_type=ModelType.LLM,
credentials={"api_key": "sk-***"},
)
def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
service = _prepare_context(load_balancing_module, monkeypatch)
service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")
with app.test_request_context(
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
method="POST",
json=_request_payload(),
):
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
assert response == {"result": "error", "error": "invalid credentials"}
def test_validate_credentials_requires_privileged_role(
app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
):
_prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)
with app.test_request_context(
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
method="POST",
json=_request_payload(),
):
api = load_balancing_module.LoadBalancingCredentialsValidateApi()
with pytest.raises(Forbidden):
api.post(provider="openai")
def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
service = _prepare_context(load_balancing_module, monkeypatch)
with app.test_request_context(
"/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
method="POST",
json=_request_payload(),
):
response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
provider="openai", config_id="cfg-1"
)
assert response == {"result": "success"}
service.validate_load_balancing_credentials.assert_called_once_with(
tenant_id="tenant-123",
provider="openai",
model="gpt-4o",
model_type=ModelType.LLM,
credentials={"api_key": "sk-***"},
config_id="cfg-1",
)

View File

@ -1,103 +0,0 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import ReconnectResult
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
# They are intentionally no-ops because the test already patches the required
# behaviors explicitly via @patch and context managers below.
@pytest.fixture
def _mock_cache():
return
@pytest.fixture
def _mock_user_tenant():
return
@pytest.fixture
def client():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
api = Api(app)
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
db.init_app(app)
# Configure session factory used by controller code
with app.app_context():
configure_session_factory(db.engine)
return app.test_client()
@patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
)
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
@patch("controllers.console.workspace.tool_providers.Session")
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(
mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
):
# Arrange: reconnect returns tools immediately
mock_reconnect.return_value = ReconnectResult(
authed=True,
tools=json.dumps(
[{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
),
encrypted_credentials="{}",
)
# Fake service.create_provider -> returns object with id for reload
svc = MagicMock()
create_result = MagicMock()
create_result.id = "provider-1"
svc.create_provider.return_value = create_result
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
mock_session.return_value.__enter__.return_value = MagicMock()
# Patch MCPToolManageService constructed inside controller
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
payload = {
"server_url": "http://example.com/mcp",
"name": "demo",
"icon": "😀",
"icon_type": "emoji",
"icon_background": "#000",
"server_identifier": "demo-sid",
"configuration": {"timeout": 5, "sse_read_timeout": 30},
"headers": {},
"authentication": {},
}
# Act
with (
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
patch(
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
),
):
resp = client.post(
"/console/api/workspaces/current/tool-provider/mcp",
data=json.dumps(payload),
content_type="application/json",
)
# Assert
assert resp.status_code == 200
body = resp.get_json()
assert body.get("id") == "provider-1"
# 若 transform 后包含 tools 字段,确保非空
assert isinstance(body.get("tools"), list)
assert body["tools"]

View File

@ -1,245 +0,0 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
import core.app.apps.workflow.app_runner as workflow_app_runner
from core.app.apps.workflow.app_runner import WorkflowAppRunner, WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import NodeType, SystemVariableKey, WorkflowType
from core.workflow.system_variable import SystemVariable
def test_prepare_single_iteration_injects_system_variables_and_fake_workflow():
node_id = "iteration_node"
execution_id = "workflow-exec-123"
workflow = SimpleNamespace(
id="workflow-id",
tenant_id="tenant-id",
app_id="app-id",
environment_variables=[],
graph_dict={
"nodes": [
{
"id": node_id,
"type": "custom",
"data": {
"type": NodeType.ITERATION,
"title": "Iteration",
"version": "1",
"iterator_selector": ["start", "items"],
"output_selector": [node_id, "output"],
},
}
],
"edges": [],
},
)
runner = WorkflowBasedAppRunner(queue_manager=MagicMock(), app_id="app-id")
system_inputs = SystemVariable(app_id="app-id", workflow_id="workflow-id", workflow_execution_id=execution_id)
graph, _, runtime_state = runner._prepare_single_node_execution(
workflow=workflow,
single_iteration_run=SimpleNamespace(node_id=node_id, inputs={"input_selector": [1, 2, 3]}),
system_variables=system_inputs,
)
assert runtime_state.variable_pool.system_variables.workflow_execution_id == execution_id
assert runtime_state.variable_pool.get_by_prefix("sys")[SystemVariableKey.WORKFLOW_EXECUTION_ID] == execution_id
assert graph.root_node.id == f"{node_id}_single_step_start"
assert f"{node_id}_single_step_end" in graph.nodes
def test_prepare_single_loop_injects_system_variables_and_fake_workflow():
node_id = "loop_node"
execution_id = "workflow-exec-456"
workflow = SimpleNamespace(
id="workflow-id",
tenant_id="tenant-id",
app_id="app-id",
environment_variables=[],
graph_dict={
"nodes": [
{
"id": node_id,
"type": "custom",
"data": {
"type": NodeType.LOOP,
"title": "Loop",
"version": "1",
"loop_count": 1,
"break_conditions": [],
"logical_operator": "and",
"loop_variables": [],
"outputs": {},
},
}
],
"edges": [],
},
)
runner = WorkflowBasedAppRunner(queue_manager=MagicMock(), app_id="app-id")
system_inputs = SystemVariable(app_id="app-id", workflow_id="workflow-id", workflow_execution_id=execution_id)
graph, _, runtime_state = runner._prepare_single_node_execution(
workflow=workflow,
single_loop_run=SimpleNamespace(node_id=node_id, inputs={}),
system_variables=system_inputs,
)
assert runtime_state.variable_pool.system_variables.workflow_execution_id == execution_id
assert graph.root_node.id == f"{node_id}_single_step_start"
assert f"{node_id}_single_step_end" in graph.nodes
class DummyCommandChannel:
def fetch_commands(self):
return []
def send_command(self, command):
return None
def _empty_graph_engine_run(self):
if False: # pragma: no cover
yield None
def _build_generate_entity(*, single_iteration_run=None, single_loop_run=None):
if isinstance(single_iteration_run, dict):
single_iteration_run = SimpleNamespace(**single_iteration_run)
if isinstance(single_loop_run, dict):
single_loop_run = SimpleNamespace(**single_loop_run)
base = SimpleNamespace(
app_config=SimpleNamespace(app_id="app-id", workflow_id="workflow-id"),
workflow_execution_id="workflow-exec-id",
files=[],
user_id="user-id",
inputs={},
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
task_id="task-id",
trace_manager=None,
single_iteration_run=single_iteration_run,
single_loop_run=single_loop_run,
)
def is_single_stepping_container_nodes():
return base.single_iteration_run is not None or base.single_loop_run is not None
base.is_single_stepping_container_nodes = is_single_stepping_container_nodes # type: ignore[attr-defined]
return base
def test_workflow_runner_attaches_persistence_for_full_run(monkeypatch):
from core.workflow.graph_engine.graph_engine import GraphEngine
monkeypatch.setattr(GraphEngine, "run", _empty_graph_engine_run)
persistence_ctor = MagicMock(name="persistence_layer_ctor")
monkeypatch.setattr(workflow_app_runner, "WorkflowPersistenceLayer", persistence_ctor)
monkeypatch.setattr(workflow_app_runner, "RedisChannel", lambda *args, **kwargs: DummyCommandChannel())
queue_manager = MagicMock()
workflow = SimpleNamespace(
id="workflow-id",
tenant_id="tenant-id",
app_id="app-id",
type=WorkflowType.WORKFLOW,
version="1",
graph_dict={
"nodes": [
{
"id": "start",
"type": "custom",
"data": {"type": NodeType.START, "title": "Start", "version": "1", "variables": []},
},
{
"id": "end",
"type": "custom",
"data": {"type": NodeType.END, "title": "End", "version": "1", "outputs": []},
},
],
"edges": [
{"source": "start", "target": "end", "sourceHandle": "source", "targetHandle": "target"},
],
},
environment_variables=[],
)
generate_entity = _build_generate_entity()
generate_entity.inputs = {"input": "value"}
runner = WorkflowAppRunner(
application_generate_entity=generate_entity,
queue_manager=queue_manager,
variable_loader=MagicMock(),
workflow=workflow,
system_user_id="system-user-id",
root_node_id=None,
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
graph_engine_layers=(),
)
runner.run()
assert persistence_ctor.call_count == 1
def test_workflow_runner_skips_persistence_for_single_step(monkeypatch):
from core.workflow.graph_engine.graph_engine import GraphEngine
monkeypatch.setattr(GraphEngine, "run", _empty_graph_engine_run)
persistence_ctor = MagicMock(name="persistence_layer_ctor")
monkeypatch.setattr(workflow_app_runner, "WorkflowPersistenceLayer", persistence_ctor)
monkeypatch.setattr(workflow_app_runner, "RedisChannel", lambda *args, **kwargs: DummyCommandChannel())
queue_manager = MagicMock()
workflow = SimpleNamespace(
id="workflow-id",
tenant_id="tenant-id",
app_id="app-id",
type=WorkflowType.WORKFLOW,
version="1",
graph_dict={
"nodes": [
{
"id": "loop",
"type": "custom",
"data": {
"type": NodeType.LOOP,
"title": "Loop",
"version": "1",
"loop_count": 1,
"break_conditions": [],
"logical_operator": "and",
"loop_variables": [],
"outputs": {},
},
}
],
"edges": [],
},
environment_variables=[],
)
generate_entity = _build_generate_entity(single_loop_run={"node_id": "loop", "inputs": {}})
runner = WorkflowAppRunner(
application_generate_entity=generate_entity,
queue_manager=queue_manager,
variable_loader=MagicMock(),
workflow=workflow,
system_user_id="system-user-id",
root_node_id=None,
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
graph_engine_layers=(),
)
runner.run()
assert persistence_ctor.call_count == 0

View File

@ -96,6 +96,9 @@ class TestToolProviderListCache:
ToolProviderListCache.invalidate_cache(tenant_id)
mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
mock_redis_client.delete.assert_called_once_with(*mock_keys)
def test_invalidate_cache_no_keys(self, mock_redis_client):
"""Test invalidate cache - no cache keys for tenant"""
tenant_id = "tenant_123"

View File

@ -1,327 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.pgvector.pgvector import (
PGVector,
PGVectorConfig,
)
class TestPGVector(unittest.TestCase):
def setUp(self):
self.config = PGVectorConfig(
host="localhost",
port=5432,
user="test_user",
password="test_password",
database="test_db",
min_connection=1,
max_connection=5,
pg_bigm=False,
)
self.collection_name = "test_collection"
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_init(self, mock_pool_class):
"""Test PGVector initialization."""
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
pgvector = PGVector(self.collection_name, self.config)
assert pgvector._collection_name == self.collection_name
assert pgvector.table_name == f"embedding_{self.collection_name}"
assert pgvector.get_type() == "pgvector"
assert pgvector.pool is not None
assert pgvector.pg_bigm is False
assert pgvector.index_hash is not None
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_init_with_pg_bigm(self, mock_pool_class):
"""Test PGVector initialization with pg_bigm enabled."""
config = PGVectorConfig(
host="localhost",
port=5432,
user="test_user",
password="test_password",
database="test_db",
min_connection=1,
max_connection=5,
pg_bigm=True,
)
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
pgvector = PGVector(self.collection_name, config)
assert pgvector.pg_bigm is True
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_basic(self, mock_redis, mock_pool_class):
"""Test basic collection creation."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Verify SQL execution calls
assert mock_cursor.execute.called
# Check that CREATE TABLE was called with correct dimension
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
assert len(create_table_calls) == 1
assert "vector(1536)" in create_table_calls[0][0][0]
# Check that CREATE INDEX was called (dimension <= 2000)
create_index_calls = [
call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
]
assert len(create_index_calls) == 1
# Verify Redis cache was set
mock_redis.set.assert_called_once()
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
"""Test collection creation with dimension > 2000 (no HNSW index)."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(3072) # Dimension > 2000
# Check that CREATE TABLE was called
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
assert len(create_table_calls) == 1
assert "vector(3072)" in create_table_calls[0][0][0]
# Check that HNSW index was NOT created (dimension > 2000)
hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
assert len(hnsw_index_calls) == 0
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
"""Test collection creation with pg_bigm enabled."""
config = PGVectorConfig(
host="localhost",
port=5432,
user="test_user",
password="test_password",
database="test_db",
min_connection=1,
max_connection=5,
pg_bigm=True,
)
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, config)
pgvector._create_collection(1536)
# Check that pg_bigm index was created
bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
assert len(bigm_index_calls) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
"""Test that vector extension is created if it doesn't exist."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
# First call: vector extension doesn't exist
mock_cursor.fetchone.return_value = None
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Check that CREATE EXTENSION was called
create_extension_calls = [
call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call)
]
assert len(create_extension_calls) == 1
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
"""Test that collection creation is skipped when cache exists."""
# Mock Redis operations - cache exists
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = 1 # Cache exists
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Check that no SQL was executed (early return due to cache)
assert mock_cursor.execute.call_count == 0
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
"""Test that Redis lock is used during collection creation."""
# Mock Redis operations
mock_lock = MagicMock()
mock_lock.__enter__ = MagicMock()
mock_lock.__exit__ = MagicMock()
mock_redis.lock.return_value = mock_lock
mock_redis.get.return_value = None
mock_redis.set.return_value = None
# Mock the connection pool
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
# Mock connection and cursor
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = [1] # vector extension exists
pgvector = PGVector(self.collection_name, self.config)
pgvector._create_collection(1536)
# Verify Redis lock was acquired with correct lock name
mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)
# Verify lock context manager was entered and exited
mock_lock.__enter__.assert_called_once()
mock_lock.__exit__.assert_called_once()
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
def test_get_cursor_context_manager(self, mock_pool_class):
"""Test that _get_cursor properly manages connection lifecycle."""
mock_pool = MagicMock()
mock_pool_class.return_value = mock_pool
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_pool.getconn.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor
pgvector = PGVector(self.collection_name, self.config)
with pgvector._get_cursor() as cur:
assert cur == mock_cursor
# Verify connection lifecycle methods were called
mock_pool.getconn.assert_called_once()
mock_cursor.close.assert_called_once()
mock_conn.commit.assert_called_once()
mock_pool.putconn.assert_called_once_with(mock_conn)
@pytest.mark.parametrize(
"invalid_config_override",
[
{"host": ""}, # Test empty host
{"port": 0}, # Test invalid port
{"user": ""}, # Test empty user
{"password": ""}, # Test empty password
{"database": ""}, # Test empty database
{"min_connection": 0}, # Test invalid min_connection
{"max_connection": 0}, # Test invalid max_connection
{"min_connection": 10, "max_connection": 5}, # Test min > max
],
)
def test_config_validation_parametrized(invalid_config_override):
"""Test configuration validation for various invalid inputs using parametrize."""
config = {
"host": "localhost",
"port": 5432,
"user": "test_user",
"password": "test_password",
"database": "test_db",
"min_connection": 1,
"max_connection": 5,
}
config.update(invalid_config_override)
with pytest.raises(ValueError):
PGVectorConfig(**config)
if __name__ == "__main__":
unittest.main()

View File

@ -1,873 +0,0 @@
"""
Unit tests for DatasetRetrieval.process_metadata_filter_func.
This module provides comprehensive test coverage for the process_metadata_filter_func
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
filter expressions based on metadata filtering conditions.
Conditions Tested:
==================
1. **String Conditions**: contains, not contains, start with, end with
2. **Equality Conditions**: is / =, is not / ≠
3. **Null Conditions**: empty, not empty
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
5. **List Conditions**: in
6. **Edge Cases**: None values, different data types (str, int, float)
Test Architecture:
==================
- Direct instantiation of DatasetRetrieval
- Mocking of DatasetDocument model attributes
- Verification of SQLAlchemy filter expressions
- Follows Arrange-Act-Assert (AAA) pattern
Running Tests:
==============
# Run all tests in this module
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v
# Run a specific test
uv run --project api pytest \
api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
TestProcessMetadataFilterFunc::test_contains_condition -v
"""
from unittest.mock import MagicMock
import pytest
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
class TestProcessMetadataFilterFunc:
"""
Comprehensive test suite for process_metadata_filter_func method.
This test class validates all metadata filtering conditions supported by
the DatasetRetrieval class, including string operations, numeric comparisons,
null checks, and list operations.
Method Signature:
==================
def process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
) -> list:
The method builds SQLAlchemy filter expressions by:
1. Validating value is not None (except for empty/not empty conditions)
2. Using DatasetDocument.doc_metadata JSON field operations
3. Adding appropriate SQLAlchemy expressions to the filters list
4. Returning the updated filters list
Mocking Strategy:
==================
- Mock DatasetDocument.doc_metadata to avoid database dependencies
- Verify filter expressions are created correctly
- Test with various data types (str, int, float, list)
"""
@pytest.fixture
def retrieval(self):
"""
Create a DatasetRetrieval instance for testing.
Returns:
DatasetRetrieval: Instance to test process_metadata_filter_func
"""
return DatasetRetrieval()
@pytest.fixture
def mock_doc_metadata(self):
"""
Mock the DatasetDocument.doc_metadata JSON field.
The method uses DatasetDocument.doc_metadata[metadata_name] to access
JSON fields. We mock this to avoid database dependencies.
Returns:
Mock: Mocked doc_metadata attribute
"""
mock_metadata_field = MagicMock()
# Create mock for string access
mock_string_access = MagicMock()
mock_string_access.like = MagicMock()
mock_string_access.notlike = MagicMock()
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
mock_string_access.in_ = MagicMock(return_value=MagicMock())
# Create mock for float access (for numeric comparisons)
mock_float_access = MagicMock()
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
# Create mock for null checks
mock_null_access = MagicMock()
mock_null_access.is_ = MagicMock(return_value=MagicMock())
mock_null_access.isnot = MagicMock(return_value=MagicMock())
# Setup __getitem__ to return appropriate mock based on usage
def getitem_side_effect(name):
if name in ["author", "title", "category"]:
return mock_string_access
elif name in ["year", "price", "rating"]:
return mock_float_access
else:
return mock_string_access
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
mock_metadata_field.as_string.return_value = mock_string_access
mock_metadata_field.as_float.return_value = mock_float_access
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
return mock_metadata_field
# ==================== String Condition Tests ====================
def test_contains_condition_string_value(self, retrieval):
"""
Test 'contains' condition with string value.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses %value% syntax
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = "John"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_contains_condition(self, retrieval):
"""
Test 'not contains' condition.
Verifies:
- Filters list is populated with NOT LIKE expression
- Pattern matching uses %value% syntax with negation
"""
filters = []
sequence = 0
condition = "not contains"
metadata_name = "title"
value = "banned"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_start_with_condition(self, retrieval):
"""
Test 'start with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses value% syntax
"""
filters = []
sequence = 0
condition = "start with"
metadata_name = "category"
value = "tech"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_end_with_condition(self, retrieval):
"""
Test 'end with' condition.
Verifies:
- Filters list is populated with LIKE expression
- Pattern matching uses %value syntax
"""
filters = []
sequence = 0
condition = "end with"
metadata_name = "filename"
value = ".pdf"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Equality Condition Tests ====================
def test_is_condition_with_string_value(self, retrieval):
"""
Test 'is' (=) condition with string value.
Verifies:
- Filters list is populated with equality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = "Jane Doe"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_equals_condition_with_string_value(self, retrieval):
"""
Test '=' condition with string value.
Verifies:
- Same behavior as 'is' condition
- String comparison is used
"""
filters = []
sequence = 0
condition = "="
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_int_value(self, retrieval):
"""
Test 'is' condition with integer value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_condition_with_float_value(self, retrieval):
"""
Test 'is' condition with float value.
Verifies:
- Numeric comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "price"
value = 19.99
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_not_condition_with_string_value(self, retrieval):
"""
Test 'is not' (≠) condition with string value.
Verifies:
- Filters list is populated with inequality expression
- String comparison is used
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "author"
value = "Unknown"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_equals_condition(self, retrieval):
"""
Test '' condition with string value.
Verifies:
- Same behavior as 'is not' condition
- Inequality expression is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "category"
value = "archived"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_is_not_condition_with_numeric_value(self, retrieval):
"""
Test 'is not' condition with numeric value.
Verifies:
- Numeric inequality comparison is used
- as_float() is called on the metadata field
"""
filters = []
sequence = 0
condition = "is not"
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Null Condition Tests ====================
def test_empty_condition(self, retrieval):
"""
Test 'empty' condition (null check).
Verifies:
- Filters list is populated with IS NULL expression
- Value can be None for this condition
"""
filters = []
sequence = 0
condition = "empty"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_not_empty_condition(self, retrieval):
"""
Test 'not empty' condition (not null check).
Verifies:
- Filters list is populated with IS NOT NULL expression
- Value can be None for this condition
"""
filters = []
sequence = 0
condition = "not empty"
metadata_name = "description"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Numeric Comparison Tests ====================
def test_before_condition(self, retrieval):
"""
Test 'before' (<) condition.
Verifies:
- Filters list is populated with less than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "before"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_condition(self, retrieval):
"""
Test '<' condition.
Verifies:
- Same behavior as 'before' condition
- Less than expression is used
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "price"
value = 100.0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_after_condition(self, retrieval):
"""
Test 'after' (>) condition.
Verifies:
- Filters list is populated with greater than expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = "after"
metadata_name = "year"
value = 2020
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_condition(self, retrieval):
"""
Test '>' condition.
Verifies:
- Same behavior as 'after' condition
- Greater than expression is used
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_or_equal_condition_unicode(self, retrieval):
"""
Test '' condition.
Verifies:
- Filters list is populated with less than or equal expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "price"
value = 50.0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_less_than_or_equal_condition_ascii(self, retrieval):
"""
Test '<=' condition.
Verifies:
- Same behavior as '' condition
- Less than or equal expression is used
"""
filters = []
sequence = 0
condition = "<="
metadata_name = "year"
value = 2023
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_or_equal_condition_unicode(self, retrieval):
"""
Test '' condition.
Verifies:
- Filters list is populated with greater than or equal expression
- Numeric comparison is used
"""
filters = []
sequence = 0
condition = ""
metadata_name = "rating"
value = 3.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_greater_than_or_equal_condition_ascii(self, retrieval):
"""
Test '>=' condition.
Verifies:
- Same behavior as '' condition
- Greater than or equal expression is used
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "year"
value = 2000
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== List/In Condition Tests ====================
def test_in_condition_with_comma_separated_string(self, retrieval):
"""
Test 'in' condition with comma-separated string value.
Verifies:
- String is split into list
- Whitespace is trimmed from each value
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "tech, science, AI "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_list_value(self, retrieval):
"""
Test 'in' condition with list value.
Verifies:
- List is processed correctly
- None values are filtered out
- IN expression is created with valid values
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "tags"
value = ["python", "javascript", None, "golang"]
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_tuple_value(self, retrieval):
"""
Test 'in' condition with tuple value.
Verifies:
- Tuple is processed like a list
- IN expression is created
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ("tech", "science", "ai")
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_empty_string(self, retrieval):
"""
Test 'in' condition with empty string value.
Verifies:
- Empty string results in literal(False) filter
- No valid values to match
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# Verify it's a literal(False) expression
# This is a bit tricky to test without access to the actual expression
def test_in_condition_with_only_whitespace(self, retrieval):
"""
Test 'in' condition with whitespace-only string value.
Verifies:
- Whitespace-only string results in literal(False) filter
- All values are stripped and filtered out
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = " , , "
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_in_condition_with_single_string(self, retrieval):
"""
Test 'in' condition with single non-comma string.
Verifies:
- Single string is treated as single-item list
- IN expression is created with one value
"""
filters = []
sequence = 0
condition = "in"
metadata_name = "category"
value = "technology"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
# ==================== Edge Case Tests ====================
def test_none_value_with_non_empty_condition(self, retrieval):
"""
Test None value with conditions that require value.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values (except empty/not empty)
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0 # No filter added
def test_none_value_with_equals_condition(self, retrieval):
"""
Test None value with 'is' (=) condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = "is"
metadata_name = "author"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_none_value_with_numeric_condition(self, retrieval):
"""
Test None value with numeric comparison condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for None values
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "year"
value = None
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_existing_filters_preserved(self, retrieval):
"""
Test that existing filters are preserved.
Verifies:
- Existing filters in the list are not removed
- New filters are appended to the list
"""
existing_filter = MagicMock()
filters = [existing_filter]
sequence = 0
condition = "contains"
metadata_name = "author"
value = "test"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 2
assert filters[0] == existing_filter
def test_multiple_filters_accumulated(self, retrieval):
"""
Test multiple calls to accumulate filters.
Verifies:
- Each call adds a new filter to the list
- All filters are preserved across calls
"""
filters = []
# First filter
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
assert len(filters) == 1
# Second filter
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
assert len(filters) == 2
# Third filter
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
assert len(filters) == 3
def test_unknown_condition(self, retrieval):
"""
Test unknown/unsupported condition.
Verifies:
- Original filters list is returned unchanged
- No filter is added for unknown conditions
"""
filters = []
sequence = 0
condition = "unknown_condition"
metadata_name = "author"
value = "test"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 0
def test_empty_string_value_with_contains(self, retrieval):
"""
Test empty string value with 'contains' condition.
Verifies:
- Filter is added even with empty string
- LIKE expression is created
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "author"
value = ""
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_special_characters_in_value(self, retrieval):
"""
Test special characters in value string.
Verifies:
- Special characters are handled in value
- LIKE expression is created correctly
"""
filters = []
sequence = 0
condition = "contains"
metadata_name = "title"
value = "C++ & Python's features"
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_zero_value_with_numeric_condition(self, retrieval):
"""
Test zero value with numeric comparison condition.
Verifies:
- Zero is treated as valid value
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">"
metadata_name = "price"
value = 0
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_negative_value_with_numeric_condition(self, retrieval):
"""
Test negative value with numeric comparison condition.
Verifies:
- Negative numbers are handled correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = "<"
metadata_name = "temperature"
value = -10.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1
def test_float_value_with_integer_comparison(self, retrieval):
"""
Test float value with numeric comparison condition.
Verifies:
- Float values work correctly
- Numeric comparison is performed
"""
filters = []
sequence = 0
condition = ">="
metadata_name = "rating"
value = 4.5
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
assert result == filters
assert len(filters) == 1

View File

@ -1,51 +0,0 @@
from core.workflow.entities import GraphInitParams
from core.workflow.graph_events import (
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
def test_iteration_node_emits_iteration_events_when_iterator_empty():
init_params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config={},
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable.empty(), user_inputs={}),
start_at=0.0,
)
runtime_state.variable_pool.add(("start", "items"), [])
node = IterationNode(
id="iteration-node",
config={
"id": "iteration-node",
"data": {
"title": "Iteration",
"iterator_selector": ["start", "items"],
"output_selector": ["iteration-node", "output"],
},
},
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
events = list(node.run())
assert any(isinstance(event, NodeRunIterationStartedEvent) for event in events)
iteration_succeeded_event = next(event for event in events if isinstance(event, NodeRunIterationSucceededEvent))
assert iteration_succeeded_event.steps == 0
assert iteration_succeeded_event.outputs == {"output": []}
assert any(isinstance(event, NodeRunSucceededEvent) for event in events)

View File

@ -1,122 +0,0 @@
import base64
from unittest.mock import Mock, patch
import pytest
from core.mcp.types import (
AudioContent,
BlobResourceContents,
CallToolResult,
EmbeddedResource,
ImageContent,
TextResourceContents,
)
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
from core.tools.mcp_tool.tool import MCPTool
def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool:
identity = ToolIdentity(
author="test",
name="test_mcp_tool",
label=I18nObject(en_US="Test MCP Tool", zh_Hans="测试MCP工具"),
provider="test_provider",
)
entity = ToolEntity(identity=identity, output_schema=output_schema or {})
runtime = Mock(spec=ToolRuntime)
runtime.credentials = {}
return MCPTool(
entity=entity,
runtime=runtime,
tenant_id="test_tenant",
icon="",
server_url="https://server.invalid",
provider_id="provider_1",
headers={},
)
class TestMCPToolInvoke:
@pytest.mark.parametrize(
("content_factory", "mime_type"),
[
(
lambda b64, mt: ImageContent(type="image", data=b64, mimeType=mt),
"image/png",
),
(
lambda b64, mt: AudioContent(type="audio", data=b64, mimeType=mt),
"audio/mpeg",
),
],
)
def test_invoke_image_or_audio_yields_blob(self, content_factory, mime_type) -> None:
tool = _make_mcp_tool()
raw = b"\x00\x01test-bytes\x02"
b64 = base64.b64encode(raw).decode()
content = content_factory(b64, mime_type)
result = CallToolResult(content=[content])
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
assert len(messages) == 1
msg = messages[0]
assert msg.type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
assert msg.message.blob == raw
assert msg.meta == {"mime_type": mime_type}
def test_invoke_embedded_text_resource_yields_text(self) -> None:
tool = _make_mcp_tool()
text_resource = TextResourceContents(uri="file://test.txt", mimeType="text/plain", text="hello world")
content = EmbeddedResource(type="resource", resource=text_resource)
result = CallToolResult(content=[content])
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
assert len(messages) == 1
msg = messages[0]
assert msg.type == ToolInvokeMessage.MessageType.TEXT
assert isinstance(msg.message, ToolInvokeMessage.TextMessage)
assert msg.message.text == "hello world"
@pytest.mark.parametrize(
("mime_type", "expected_mime"),
[("application/pdf", "application/pdf"), (None, "application/octet-stream")],
)
def test_invoke_embedded_blob_resource_yields_blob(self, mime_type, expected_mime) -> None:
tool = _make_mcp_tool()
raw = b"binary-data"
b64 = base64.b64encode(raw).decode()
blob_resource = BlobResourceContents(uri="file://doc.bin", mimeType=mime_type, blob=b64)
content = EmbeddedResource(type="resource", resource=blob_resource)
result = CallToolResult(content=[content])
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
assert len(messages) == 1
msg = messages[0]
assert msg.type == ToolInvokeMessage.MessageType.BLOB
assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
assert msg.message.blob == raw
assert msg.meta == {"mime_type": expected_mime}
def test_invoke_yields_variables_when_structured_content_and_schema(self) -> None:
tool = _make_mcp_tool(output_schema={"type": "object"})
result = CallToolResult(content=[], structuredContent={"a": 1, "b": "x"})
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
messages = list(tool._invoke(user_id="test_user", tool_parameters={}))
# Expect two variable messages corresponding to keys a and b
assert len(messages) == 2
var_msgs = [m for m in messages if isinstance(m.message, ToolInvokeMessage.VariableMessage)]
assert {m.message.variable_name for m in var_msgs} == {"a", "b"}
# Validate values
values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
assert values == {"a": 1, "b": "x"}

12
api/uv.lock generated
View File

@ -1953,14 +1953,14 @@ wheels = [
[[package]]
name = "fickling"
version = "0.1.5"
version = "0.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "stdlib-list" },
]
sdist = { url = "https://files.pythonhosted.org/packages/41/94/0d0ce455952c036cfee235637f786c1d1d07d1b90f6a4dfb50e0eff929d6/fickling-0.1.5.tar.gz", hash = "sha256:92f9b49e717fa8dbc198b4b7b685587adb652d85aa9ede8131b3e44494efca05", size = 282462, upload-time = "2025-11-18T05:04:30.748Z" }
sdist = { url = "https://files.pythonhosted.org/packages/07/ab/7571453f9365c17c047b5a7b7e82692a7f6be51203f295030886758fd57a/fickling-0.1.6.tar.gz", hash = "sha256:03cb5d7bd09f9169c7583d2079fad4b3b88b25f865ed0049172e5cb68582311d", size = 284033, upload-time = "2025-12-15T18:14:58.721Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bf/a7/d25912b2e3a5b0a37e6f460050bbc396042b5906a6563a1962c484abc3c6/fickling-0.1.5-py3-none-any.whl", hash = "sha256:6aed7270bfa276e188b0abe043a27b3a042129d28ec1fa6ff389bdcc5ad178bb", size = 46240, upload-time = "2025-11-18T05:04:29.048Z" },
{ url = "https://files.pythonhosted.org/packages/76/99/cc04258dda421bc612cdfe4be8c253f45b922f1c7f268b5a0b9962d9cd12/fickling-0.1.6-py3-none-any.whl", hash = "sha256:465d0069548bfc731bdd75a583cb4cf5a4b2666739c0f76287807d724b147ed3", size = 47922, upload-time = "2025-12-15T18:14:57.526Z" },
]
[[package]]
@ -3072,11 +3072,11 @@ wheels = [
[[package]]
name = "json-repair"
version = "0.54.3"
version = "0.54.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b5/86/48b12ac02032f121ac7e5f11a32143edca6c1e3d19ffc54d6fb9ca0aafd0/json_repair-0.54.3.tar.gz", hash = "sha256:e50feec9725e52ac91f12184609754684ac1656119dfbd31de09bdaf9a1d8bf6", size = 38626, upload-time = "2025-12-15T09:41:58.594Z" }
sdist = { url = "https://files.pythonhosted.org/packages/00/46/d3a4d9a3dad39bb4a2ad16b8adb9fe2e8611b20b71197fe33daa6768e85d/json_repair-0.54.1.tar.gz", hash = "sha256:d010bc31f1fc66e7c36dc33bff5f8902674498ae5cb8e801ad455a53b455ad1d", size = 38555, upload-time = "2025-11-19T14:55:24.265Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/08/abe317237add63c3e62f18a981bccf92112b431835b43d844aedaf61f4a0/json_repair-0.54.3-py3-none-any.whl", hash = "sha256:4cdc132ee27d4780576f71bf27a113877046224a808bfc17392e079cb344fb81", size = 29357, upload-time = "2025-12-15T09:41:57.436Z" },
{ url = "https://files.pythonhosted.org/packages/db/96/c9aad7ee949cc1bf15df91f347fbc2d3bd10b30b80c7df689ce6fe9332b5/json_repair-0.54.1-py3-none-any.whl", hash = "sha256:016160c5db5d5fe443164927bb58d2dfbba5f43ad85719fa9bc51c713a443ab1", size = 29311, upload-time = "2025-11-19T14:55:22.886Z" },
]
[[package]]

View File

@ -399,7 +399,6 @@ CONSOLE_CORS_ALLOW_ORIGINS=*
COOKIE_DOMAIN=
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
NEXT_PUBLIC_COOKIE_DOMAIN=
NEXT_PUBLIC_BATCH_CONCURRENCY=5
# ------------------------------
# File Storage Configuration

View File

@ -108,7 +108,6 @@ x-shared-env: &shared-api-worker-env
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
COOKIE_DOMAIN: ${COOKIE_DOMAIN:-}
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
NEXT_PUBLIC_BATCH_CONCURRENCY: ${NEXT_PUBLIC_BATCH_CONCURRENCY:-5}
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}

View File

@ -54,17 +54,17 @@
"publish:npm": "./scripts/publish.sh"
},
"dependencies": {
"axios": "^1.13.2"
"axios": "^1.3.5"
},
"devDependencies": {
"@eslint/js": "^9.39.2",
"@types/node": "^25.0.3",
"@eslint/js": "^9.2.0",
"@types/node": "^20.11.30",
"@typescript-eslint/eslint-plugin": "^8.50.1",
"@typescript-eslint/parser": "^8.50.1",
"@vitest/coverage-v8": "4.0.16",
"eslint": "^9.39.2",
"@vitest/coverage-v8": "1.6.1",
"eslint": "^9.2.0",
"tsup": "^8.5.1",
"typescript": "^5.9.3",
"vitest": "^4.0.16"
"typescript": "^5.4.5",
"vitest": "^1.5.0"
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,2 +0,0 @@
onlyBuiltDependencies:
- esbuild

View File

@ -73,6 +73,3 @@ NEXT_PUBLIC_MAX_TREE_DEPTH=50
# The API key of amplitude
NEXT_PUBLIC_AMPLITUDE_API_KEY=
# number of concurrency
NEXT_PUBLIC_BATCH_CONCURRENCY=5

View File

@ -1,5 +1,5 @@
# base image
FROM node:22-alpine3.21 AS base
FROM node:24-alpine AS base
LABEL maintainer="takatost@gmail.com"
# if you located in China, you can use aliyun mirror to speed up

View File

@ -8,8 +8,8 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next
Before starting the web frontend service, please make sure the following environment is ready.
- [Node.js](https://nodejs.org) >= v22.11.x
- [pnpm](https://pnpm.io) v10.x
- [Node.js](https://nodejs.org)
- [pnpm](https://pnpm.io)
First, install the dependencies:

View File

@ -1,6 +1,5 @@
import type { ReactNode } from 'react'
import * as React from 'react'
import { AppInitializer } from '@/app/components/app-initializer'
import AmplitudeProvider from '@/app/components/base/amplitude'
import GA, { GaType } from '@/app/components/base/ga'
import Zendesk from '@/app/components/base/zendesk'
@ -8,6 +7,7 @@ import GotoAnything from '@/app/components/goto-anything'
import Header from '@/app/components/header'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import ReadmePanel from '@/app/components/plugins/readme-panel'
import SwrInitializer from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ModalContextProvider } from '@/context/modal-context'
@ -20,7 +20,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
<>
<GA gaType={GaType.admin} />
<AmplitudeProvider />
<AppInitializer>
<SwrInitializer>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
@ -38,7 +38,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
</EventEmitterContextProvider>
</AppContextProvider>
<Zendesk />
</AppInitializer>
</SwrInitializer>
</>
)
}

View File

@ -14,6 +14,7 @@ import { useWebAppStore } from '@/context/web-app-context'
import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common'
import { fetchAccessToken } from '@/service/share'
import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth'
import { encryptVerificationCode } from '@/utils/encryption'
export default function CheckCode() {
const { t } = useTranslation()
@ -64,7 +65,7 @@ export default function CheckCode() {
return
}
setIsLoading(true)
const ret = await webAppEmailLoginWithCode({ email, code, token })
const ret = await webAppEmailLoginWithCode({ email, code: encryptVerificationCode(code), token })
if (ret.result === 'success') {
setWebAppAccessToken(ret.data.access_token)
const { access_token } = await fetchAccessToken({

View File

@ -14,6 +14,7 @@ import { useWebAppStore } from '@/context/web-app-context'
import { webAppLogin } from '@/service/common'
import { fetchAccessToken } from '@/service/share'
import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth'
import { encryptPassword } from '@/utils/encryption'
type MailAndPasswordAuthProps = {
isEmailSetup: boolean
@ -72,7 +73,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
setIsLoading(true)
const loginData: Record<string, any> = {
email,
password,
password: encryptPassword(password),
language: locale,
remember_me: true,
}

View File

@ -1,9 +1,9 @@
import type { ReactNode } from 'react'
import * as React from 'react'
import { AppInitializer } from '@/app/components/app-initializer'
import AmplitudeProvider from '@/app/components/base/amplitude'
import GA, { GaType } from '@/app/components/base/ga'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import SwrInitor from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ModalContextProvider } from '@/context/modal-context'
@ -15,7 +15,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
<>
<GA gaType={GaType.admin} />
<AmplitudeProvider />
<AppInitializer>
<SwrInitor>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
@ -30,7 +30,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
</ProviderContextProvider>
</EventEmitterContextProvider>
</AppContextProvider>
</AppInitializer>
</SwrInitor>
</>
)
}

View File

@ -47,12 +47,6 @@ const getCheckboxDefaultSelectValue = (value: InputVar['default']) => {
const parseCheckboxSelectValue = (value: string) =>
value === CHECKBOX_DEFAULT_TRUE_VALUE
const normalizeSelectDefaultValue = (inputVar: InputVar) => {
if (inputVar.type === InputVarType.select && inputVar.default === '')
return { ...inputVar, default: undefined }
return inputVar
}
export type IConfigModalProps = {
isCreate?: boolean
payload?: InputVar
@ -73,7 +67,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
}) => {
const { modelConfig } = useContext(ConfigContext)
const { t } = useTranslation()
const [tempPayload, setTempPayload] = useState<InputVar>(() => normalizeSelectDefaultValue(payload || getNewVarInWorkflow('') as any))
const [tempPayload, setTempPayload] = useState<InputVar>(() => payload || getNewVarInWorkflow('') as any)
const { type, label, variable, options, max_length } = tempPayload
const modalRef = useRef<HTMLDivElement>(null)
const appDetail = useAppStore(state => state.appDetail)
@ -188,8 +182,6 @@ const ConfigModal: FC<IConfigModalProps> = ({
const newPayload = produce(tempPayload, (draft) => {
draft.type = type
if (type === InputVarType.select)
draft.default = undefined
if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) {
(Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => {
if (key !== 'max_length')

View File

@ -176,7 +176,7 @@ const DatasetConfig: FC = () => {
}))
}, [setDatasetConfigs, datasetConfigsRef])
const handleAddCondition = useCallback<HandleAddCondition>(({ id, name, type }) => {
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
let operator: ComparisonOperator = ComparisonOperator.is
if (type === MetadataFilteringVariableType.number)
@ -184,7 +184,6 @@ const DatasetConfig: FC = () => {
const newCondition = {
id: uuid4(),
metadata_id: id, // Save metadata.id for reliable reference
name,
comparison_operator: operator,
}

View File

@ -1,141 +0,0 @@
import type { DataSet } from '@/models/datasets'
import { act, fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { describe, expect, it, vi } from 'vitest'
import { IndexingType } from '@/app/components/datasets/create/step-two'
import { DatasetPermission } from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app'
import SelectDataSet from './index'
vi.mock('@/i18n-config/i18next-config', () => ({
__esModule: true,
default: {
changeLanguage: vi.fn(),
addResourceBundle: vi.fn(),
use: vi.fn().mockReturnThis(),
init: vi.fn(),
addResource: vi.fn(),
hasResourceBundle: vi.fn().mockReturnValue(true),
},
}))
const mockUseInfiniteScroll = vi.fn()
vi.mock('ahooks', async (importOriginal) => {
const actual = await importOriginal()
return {
...(typeof actual === 'object' && actual !== null ? actual : {}),
useInfiniteScroll: (...args: any[]) => mockUseInfiniteScroll(...args),
}
})
const mockUseInfiniteDatasets = vi.fn()
vi.mock('@/service/knowledge/use-dataset', () => ({
useInfiniteDatasets: (...args: any[]) => mockUseInfiniteDatasets(...args),
}))
vi.mock('@/hooks/use-knowledge', () => ({
useKnowledge: () => ({
formatIndexingTechniqueAndMethod: (tech: string, method: string) => `${tech}:${method}`,
}),
}))
const baseProps = {
isShow: true,
onClose: vi.fn(),
selectedIds: [] as string[],
onSelect: vi.fn(),
}
const makeDataset = (overrides: Partial<DataSet>): DataSet => ({
id: 'dataset-id',
name: 'Dataset Name',
provider: 'internal',
icon_info: {
icon_type: 'emoji',
icon: '💾',
icon_background: '#fff',
icon_url: '',
},
embedding_available: true,
is_multimodal: false,
description: '',
permission: DatasetPermission.allTeamMembers,
indexing_technique: IndexingType.ECONOMICAL,
retrieval_model_dict: {
search_method: RETRIEVE_METHOD.fullText,
top_k: 5,
reranking_enable: false,
reranking_model: {
reranking_model_name: '',
reranking_provider_name: '',
},
score_threshold_enabled: false,
score_threshold: 0,
},
...overrides,
} as DataSet)
describe('SelectDataSet', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('renders dataset entries, allows selection, and fires onSelect', async () => {
const datasetOne = makeDataset({
id: 'set-1',
name: 'Dataset One',
is_multimodal: true,
indexing_technique: IndexingType.ECONOMICAL,
})
const datasetTwo = makeDataset({
id: 'set-2',
name: 'Hidden Dataset',
embedding_available: false,
provider: 'external',
})
mockUseInfiniteDatasets.mockReturnValue({
data: { pages: [{ data: [datasetOne, datasetTwo] }] },
isLoading: false,
isFetchingNextPage: false,
fetchNextPage: vi.fn(),
hasNextPage: false,
})
const onSelect = vi.fn()
await act(async () => {
render(<SelectDataSet {...baseProps} onSelect={onSelect} selectedIds={[]} />)
})
expect(screen.getByText('Dataset One')).toBeInTheDocument()
expect(screen.getByText('Hidden Dataset')).toBeInTheDocument()
await act(async () => {
fireEvent.click(screen.getByText('Dataset One'))
})
expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument()
const addButton = screen.getByRole('button', { name: 'common.operation.add' })
await act(async () => {
fireEvent.click(addButton)
})
expect(onSelect).toHaveBeenCalledWith([datasetOne])
})
it('shows empty state when no datasets are available and disables add', async () => {
mockUseInfiniteDatasets.mockReturnValue({
data: { pages: [{ data: [] }] },
isLoading: false,
isFetchingNextPage: false,
fetchNextPage: vi.fn(),
hasNextPage: false,
})
await act(async () => {
render(<SelectDataSet {...baseProps} onSelect={vi.fn()} selectedIds={[]} />)
})
expect(screen.getByText('appDebug.feature.dataSet.noDataSet')).toBeInTheDocument()
expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create')
expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
})
})

View File

@ -679,7 +679,7 @@ const Configuration: FC = () => {
const toolInCollectionList = collectionList.find(c => tool.provider_id === c.id)
return {
...tool,
isDeleted: res.deleted_tools?.some((deletedTool: any) => deletedTool.provider_id === tool.provider_id && deletedTool.tool_name === tool.tool_name) ?? false,
isDeleted: res.deleted_tools?.some((deletedTool: any) => deletedTool.id === tool.id && deletedTool.tool_name === tool.tool_name) ?? false,
notAuthor: toolInCollectionList?.is_team_authorization === false,
...(tool.provider_type === 'builtin'
? {

View File

@ -1,125 +0,0 @@
import type { IPromptValuePanelProps } from './index'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useStore } from '@/app/components/app/store'
import ConfigContext from '@/context/debug-configuration'
import { AppModeEnum, ModelModeType, Resolution } from '@/types/app'
import PromptValuePanel from './index'
vi.mock('@/app/components/app/store', () => ({
useStore: vi.fn(),
}))
vi.mock('@/app/components/base/features/new-feature-panel/feature-bar', () => ({
__esModule: true,
default: ({ onFeatureBarClick }: { onFeatureBarClick: () => void }) => (
<button type="button" onClick={onFeatureBarClick}>
feature bar
</button>
),
}))
const mockSetShowAppConfigureFeaturesModal = vi.fn()
const mockUseStore = vi.mocked(useStore)
const mockSetInputs = vi.fn()
const mockOnSend = vi.fn()
const promptVariables = [
{ key: 'textVar', name: 'Text Var', type: 'string', required: true },
{ key: 'boolVar', name: 'Boolean Var', type: 'checkbox' },
] as const
const baseContextValue: any = {
modelModeType: ModelModeType.completion,
modelConfig: {
configs: {
prompt_template: 'prompt template',
prompt_variables: promptVariables,
},
},
setInputs: mockSetInputs,
mode: AppModeEnum.COMPLETION,
isAdvancedMode: false,
completionPromptConfig: {
prompt: { text: 'completion' },
conversation_histories_role: { user_prefix: 'user', assistant_prefix: 'assistant' },
},
chatPromptConfig: { prompt: [] },
} as any
const defaultProps: IPromptValuePanelProps = {
appType: AppModeEnum.COMPLETION,
onSend: mockOnSend,
inputs: { textVar: 'initial', boolVar: false },
visionConfig: { enabled: false, number_limits: 0, detail: Resolution.low, transfer_methods: [] },
onVisionFilesChange: vi.fn(),
}
const renderPanel = (options: {
context?: Partial<typeof baseContextValue>
props?: Partial<IPromptValuePanelProps>
} = {}) => {
const contextValue = { ...baseContextValue, ...options.context }
const props = { ...defaultProps, ...options.props }
return render(
<ConfigContext.Provider value={contextValue}>
<PromptValuePanel {...props} />
</ConfigContext.Provider>,
)
}
describe('PromptValuePanel', () => {
beforeEach(() => {
mockUseStore.mockImplementation(selector => selector({
setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal,
appSidebarExpand: '',
currentLogModalActiveTab: 'prompt',
showPromptLogModal: false,
showAgentLogModal: false,
setShowPromptLogModal: vi.fn(),
setShowAgentLogModal: vi.fn(),
showMessageLogModal: false,
showAppConfigureFeaturesModal: false,
} as any))
mockSetInputs.mockClear()
mockOnSend.mockClear()
mockSetShowAppConfigureFeaturesModal.mockClear()
})
it('updates inputs, clears values, and triggers run when ready', async () => {
renderPanel()
const textInput = screen.getByPlaceholderText('Text Var')
fireEvent.change(textInput, { target: { value: 'updated' } })
expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ textVar: 'updated' }))
const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
fireEvent.click(clearButton)
expect(mockSetInputs).toHaveBeenLastCalledWith({
textVar: '',
boolVar: '',
})
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
expect(runButton).not.toBeDisabled()
fireEvent.click(runButton)
await waitFor(() => expect(mockOnSend).toHaveBeenCalledTimes(1))
})
it('disables run when mode is not completion', () => {
renderPanel({
context: {
mode: AppModeEnum.CHAT,
},
props: {
appType: AppModeEnum.CHAT,
},
})
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
expect(runButton).toBeDisabled()
fireEvent.click(runButton)
expect(mockOnSend).not.toHaveBeenCalled()
})
})

View File

@ -1,29 +0,0 @@
import type { PromptVariable } from '@/models/debug'
import { describe, expect, it } from 'vitest'
import { replaceStringWithValues } from './utils'
const promptVariables: PromptVariable[] = [
{ key: 'user', name: 'User', type: 'string' },
{ key: 'topic', name: 'Topic', type: 'string' },
]
describe('replaceStringWithValues', () => {
it('should replace placeholders when inputs have values', () => {
const template = 'Hello {{user}} talking about {{topic}}'
const result = replaceStringWithValues(template, promptVariables, { user: 'Alice', topic: 'cats' })
expect(result).toBe('Hello Alice talking about cats')
})
it('should use prompt variable name when value is missing', () => {
const template = 'Hi {{user}} from {{topic}}'
const result = replaceStringWithValues(template, promptVariables, {})
expect(result).toBe('Hi {{User}} from {{Topic}}')
})
it('should leave placeholder untouched when no variable is defined', () => {
const template = 'Unknown {{missing}} placeholder'
const result = replaceStringWithValues(template, promptVariables, {})
expect(result).toBe('Unknown {{missing}} placeholder')
})
})

View File

@ -1,136 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { AppModeEnum } from '@/types/app'
import Apps from './index'
const mockUseExploreAppList = vi.fn()
vi.mock('ahooks', () => ({
useDebounceFn: (fn: () => void) => ({
run: () => setTimeout(fn, 0),
cancel: vi.fn(),
flush: () => fn(),
}),
}))
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({ isCurrentWorkspaceEditor: true }),
}))
vi.mock('use-context-selector', async () => {
const actual = await vi.importActual<typeof import('use-context-selector')>('use-context-selector')
return {
...actual,
useContext: () => ({ hasEditPermission: true }),
}
})
vi.mock('nuqs', () => ({
useQueryState: () => ['Recommended', vi.fn()],
}))
vi.mock('@/service/use-explore', () => ({
useExploreAppList: () => mockUseExploreAppList(),
}))
vi.mock('@/app/components/app/type-selector', () => ({
__esModule: true,
default: ({ value, onChange }: { value: AppModeEnum[], onChange: (value: AppModeEnum[]) => void }) => (
<button data-testid="type-selector" onClick={() => onChange([...value, 'chat' as AppModeEnum])}>{value.join(',')}</button>
),
}))
vi.mock('../app-card', () => ({
__esModule: true,
default: ({ app, onCreate }: { app: any, onCreate: () => void }) => (
<div
data-testid="app-card"
data-name={app.app.name}
onClick={onCreate}
>
{app.app.name}
</div>
),
}))
vi.mock('@/app/components/explore/create-app-modal', () => ({
__esModule: true,
default: () => <div data-testid="create-from-template-modal" />,
}))
vi.mock('@/app/components/base/toast', () => ({
default: { notify: vi.fn() },
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: vi.fn(),
}))
vi.mock('@/service/apps', () => ({
importDSL: vi.fn().mockResolvedValue({ app_id: '1' }),
}))
vi.mock('@/service/explore', () => ({
fetchAppDetail: vi.fn().mockResolvedValue({
export_data: 'dsl',
mode: 'chat',
}),
}))
vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({
usePluginDependencies: () => ({
handleCheckPluginDependencies: vi.fn(),
}),
}))
vi.mock('@/utils/app-redirection', () => ({
getRedirection: vi.fn(),
}))
vi.mock('next/navigation', () => ({
useRouter: () => ({ push: vi.fn() }),
}))
const createAppEntry = (name: string, category: string) => ({
app_id: name,
category,
app: {
id: name,
name,
icon_type: 'emoji',
icon: '🙂',
icon_background: '#000',
icon_url: null,
description: 'desc',
mode: AppModeEnum.CHAT,
},
})
describe('Apps', () => {
const defaultData = {
allList: [
createAppEntry('Alpha', 'Cat A'),
createAppEntry('Bravo', 'Cat B'),
],
categories: ['Cat A', 'Cat B'],
}
beforeEach(() => {
vi.clearAllMocks()
mockUseExploreAppList.mockReturnValue({
data: defaultData,
isLoading: false,
})
})
it('renders template cards when data is available', () => {
render(<Apps />)
expect(screen.getAllByTestId('app-card')).toHaveLength(2)
expect(screen.getByText('Alpha')).toBeInTheDocument()
expect(screen.getByText('Bravo')).toBeInTheDocument()
})
it('opens create modal when a template card is clicked', () => {
render(<Apps />)
fireEvent.click(screen.getAllByTestId('app-card')[0])
expect(screen.getByTestId('create-from-template-modal')).toBeInTheDocument()
})
it('shows no template message when list is empty', () => {
mockUseExploreAppList.mockReturnValueOnce({
data: { allList: [], categories: [] },
isLoading: false,
})
render(<Apps />)
expect(screen.getByText('app.newApp.noTemplateFound')).toBeInTheDocument()
expect(screen.getByText('app.newApp.noTemplateFoundTip')).toBeInTheDocument()
})
})

View File

@ -8,7 +8,6 @@ import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import AppTypeSelector from '@/app/components/app/type-selector'
import { trackEvent } from '@/app/components/base/amplitude'
import Divider from '@/app/components/base/divider'
@ -19,7 +18,7 @@ import CreateAppModal from '@/app/components/explore/create-app-modal'
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
import ExploreContext from '@/context/explore-context'
import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import { DSLImportMode } from '@/models/app'
import { importDSL } from '@/service/apps'
import { fetchAppDetail } from '@/service/explore'
@ -47,7 +46,6 @@ const Apps = ({
const { t } = useTranslation()
const { isCurrentWorkspaceEditor } = useAppContext()
const { push } = useRouter()
const { hasEditPermission } = useContext(ExploreContext)
const allCategoriesEn = AppCategories.RECOMMENDED
const [keywords, setKeywords] = useState('')
@ -63,7 +61,10 @@ const Apps = ({
}
const [currentType, setCurrentType] = useState<AppModeEnum[]>([])
const [currCategory, setCurrCategory] = useState<AppCategories | string>(allCategoriesEn)
const [currCategory, setCurrCategory] = useTabSearchParams({
defaultTab: allCategoriesEn,
disableSearchParams: true,
})
const {
data,
@ -214,7 +215,7 @@ const Apps = ({
<AppCard
key={app.app_id}
app={app}
canCreate={hasEditPermission}
canCreate={isCurrentWorkspaceEditor}
onCreate={() => {
setCurrApp(app)
setIsShowCreateModal(true)

View File

@ -1,38 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Sidebar, { AppCategories } from './sidebar'
vi.mock('@remixicon/react', () => ({
RiStickyNoteAddLine: () => <span>sticky</span>,
RiThumbUpLine: () => <span>thumb</span>,
}))
describe('Sidebar', () => {
it('renders recommended and custom categories', () => {
render(<Sidebar current={AppCategories.RECOMMENDED} categories={['Cat A', 'Cat B']} />)
expect(screen.getByText('app.newAppFromTemplate.sidebar.Recommended')).toBeInTheDocument()
expect(screen.getByText('Cat A')).toBeInTheDocument()
expect(screen.getByText('Cat B')).toBeInTheDocument()
})
it('notifies callbacks when items are clicked', () => {
const onClick = vi.fn()
const onCreate = vi.fn()
render(
<Sidebar
current="Cat A"
categories={['Cat A']}
onClick={onClick}
onCreateFromBlank={onCreate}
/>,
)
fireEvent.click(screen.getByText('app.newAppFromTemplate.sidebar.Recommended'))
expect(onClick).toHaveBeenCalledWith(AppCategories.RECOMMENDED)
fireEvent.click(screen.getByText('Cat A'))
expect(onClick).toHaveBeenCalledWith('Cat A')
fireEvent.click(screen.getByText('app.newApp.startFromBlank'))
expect(onCreate).toHaveBeenCalled()
})
})

View File

@ -1,162 +0,0 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { useRouter } from 'next/navigation'
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
import { trackEvent } from '@/app/components/base/amplitude'
import { ToastContext } from '@/app/components/base/toast'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { createApp } from '@/service/apps'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
import CreateAppModal from './index'
vi.mock('ahooks', () => ({
useDebounceFn: (fn: (...args: any[]) => any) => {
const run = (...args: any[]) => fn(...args)
const cancel = vi.fn()
const flush = vi.fn()
return { run, cancel, flush }
},
useKeyPress: vi.fn(),
useHover: () => false,
}))
vi.mock('next/navigation', () => ({
useRouter: vi.fn(),
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: vi.fn(),
}))
vi.mock('@/service/apps', () => ({
createApp: vi.fn(),
}))
vi.mock('@/utils/app-redirection', () => ({
getRedirection: vi.fn(),
}))
vi.mock('@/context/provider-context', () => ({
useProviderContext: vi.fn(),
}))
vi.mock('@/context/app-context', () => ({
useAppContext: vi.fn(),
}))
vi.mock('@/context/i18n', () => ({
useDocLink: () => () => '/guides',
}))
vi.mock('@/hooks/use-theme', () => ({
__esModule: true,
default: () => ({ theme: 'light' }),
}))
const mockNotify = vi.fn()
const mockUseRouter = vi.mocked(useRouter)
const mockPush = vi.fn()
const mockCreateApp = vi.mocked(createApp)
const mockTrackEvent = vi.mocked(trackEvent)
const mockGetRedirection = vi.mocked(getRedirection)
const mockUseProviderContext = vi.mocked(useProviderContext)
const mockUseAppContext = vi.mocked(useAppContext)
const defaultPlanUsage = {
buildApps: 0,
teamMembers: 0,
annotatedResponse: 0,
documentsUploadQuota: 0,
apiRateLimit: 0,
triggerEvents: 0,
vectorSpace: 0,
}
const renderModal = () => {
const onClose = vi.fn()
const onSuccess = vi.fn()
render(
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
<CreateAppModal show onClose={onClose} onSuccess={onSuccess} defaultAppMode={AppModeEnum.ADVANCED_CHAT} />
</ToastContext.Provider>,
)
return { onClose, onSuccess }
}
describe('CreateAppModal', () => {
const mockSetItem = vi.fn()
const originalLocalStorage = window.localStorage
beforeEach(() => {
vi.clearAllMocks()
mockUseRouter.mockReturnValue({ push: mockPush } as any)
mockUseProviderContext.mockReturnValue({
plan: {
type: AppModeEnum.ADVANCED_CHAT,
usage: defaultPlanUsage,
total: { ...defaultPlanUsage, buildApps: 1 },
reset: {},
},
enableBilling: true,
} as any)
mockUseAppContext.mockReturnValue({
isCurrentWorkspaceEditor: true,
} as any)
mockSetItem.mockClear()
Object.defineProperty(window, 'localStorage', {
value: {
setItem: mockSetItem,
getItem: vi.fn(),
removeItem: vi.fn(),
clear: vi.fn(),
key: vi.fn(),
length: 0,
},
writable: true,
})
})
afterAll(() => {
Object.defineProperty(window, 'localStorage', {
value: originalLocalStorage,
writable: true,
})
})
it('creates an app, notifies success, and fires callbacks', async () => {
const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
mockCreateApp.mockResolvedValue(mockApp as any)
const { onClose, onSuccess } = renderModal()
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
fireEvent.change(nameInput, { target: { value: 'My App' } })
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
name: 'My App',
description: '',
icon_type: 'emoji',
icon: '🤖',
icon_background: '#FFEAD5',
mode: AppModeEnum.ADVANCED_CHAT,
}))
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
app_mode: AppModeEnum.ADVANCED_CHAT,
description: '',
})
expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' })
expect(onSuccess).toHaveBeenCalled()
expect(onClose).toHaveBeenCalled()
await waitFor(() => expect(mockSetItem).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1'))
await waitFor(() => expect(mockGetRedirection).toHaveBeenCalledWith(true, mockApp, mockPush))
})
it('shows error toast when creation fails', async () => {
mockCreateApp.mockRejectedValue(new Error('boom'))
const { onClose } = renderModal()
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
fireEvent.change(nameInput, { target: { value: 'My App' } })
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
expect(onClose).not.toHaveBeenCalled()
})
})

View File

@ -139,14 +139,14 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t
id: item.id,
content: item.answer,
agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files),
feedback: item.feedbacks?.find(item => item.from_source === 'user'), // user feedback
adminFeedback: item.feedbacks?.find(item => item.from_source === 'admin'), // admin feedback
feedback: item.feedbacks.find(item => item.from_source === 'user'), // user feedback
adminFeedback: item.feedbacks.find(item => item.from_source === 'admin'), // admin feedback
feedbackDisabled: false,
isAnswer: true,
message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id }))),
log: [
...(item.message ?? []),
...(item.message?.[item.message.length - 1]?.role !== 'assistant'
...item.message,
...(item.message[item.message.length - 1]?.role !== 'assistant'
? [
{
role: 'assistant',
@ -165,7 +165,7 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t
more: {
time: dayjs.unix(item.created_at).tz(timezone).format(format),
tokens: item.answer_tokens + item.message_tokens,
latency: (item.provider_response_latency ?? 0).toFixed(2),
latency: item.provider_response_latency.toFixed(2),
},
citation: item.metadata?.retriever_resources,
annotation: (() => {

View File

@ -1,121 +0,0 @@
import type { SiteInfo } from '@/models/share'
import { fireEvent, render, screen } from '@testing-library/react'
import copy from 'copy-to-clipboard'
import * as React from 'react'
import { act } from 'react'
import { afterAll, afterEach, describe, expect, it, vi } from 'vitest'
import Embedded from './index'
vi.mock('./style.module.css', () => ({
__esModule: true,
default: {
option: 'option',
active: 'active',
iframeIcon: 'iframeIcon',
scriptsIcon: 'scriptsIcon',
chromePluginIcon: 'chromePluginIcon',
pluginInstallIcon: 'pluginInstallIcon',
},
}))
const mockThemeBuilder = {
buildTheme: vi.fn(),
theme: {
primaryColor: '#123456',
},
}
const mockUseAppContext = vi.fn(() => ({
langGeniusVersionInfo: {
current_env: 'PRODUCTION',
current_version: '',
latest_version: '',
release_date: '',
release_notes: '',
version: '',
can_auto_update: false,
},
}))
vi.mock('copy-to-clipboard', () => ({
__esModule: true,
default: vi.fn(),
}))
vi.mock('@/app/components/base/chat/embedded-chatbot/theme/theme-context', () => ({
useThemeContext: () => mockThemeBuilder,
}))
vi.mock('@/context/app-context', () => ({
useAppContext: () => mockUseAppContext(),
}))
const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null)
const mockedCopy = vi.mocked(copy)
const siteInfo: SiteInfo = {
title: 'test site',
chat_color_theme: '#000000',
chat_color_theme_inverted: false,
}
const baseProps = {
isShow: true,
siteInfo,
onClose: vi.fn(),
appBaseUrl: 'https://app.example.com',
accessToken: 'token',
className: 'custom-modal',
}
const getCopyButton = () => {
const buttons = screen.getAllByRole('button')
const actionButton = buttons.find(button => button.className.includes('action-btn'))
expect(actionButton).toBeDefined()
return actionButton!
}
describe('Embedded', () => {
afterEach(() => {
vi.clearAllMocks()
mockWindowOpen.mockClear()
})
afterAll(() => {
mockWindowOpen.mockRestore()
})
it('builds theme and copies iframe snippet', async () => {
await act(async () => {
render(<Embedded {...baseProps} />)
})
const actionButton = getCopyButton()
const innerDiv = actionButton.querySelector('div')
act(() => {
fireEvent.click(innerDiv ?? actionButton)
})
expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted)
expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token'))
})
it('opens chrome plugin store link when chrome option selected', async () => {
await act(async () => {
render(<Embedded {...baseProps} />)
})
const optionButtons = document.body.querySelectorAll('[class*="option"]')
expect(optionButtons.length).toBeGreaterThanOrEqual(3)
act(() => {
fireEvent.click(optionButtons[2])
})
const [chromeText] = screen.getAllByText('appOverview.overview.appInfo.embedded.chromePlugin')
act(() => {
fireEvent.click(chromeText)
})
expect(mockWindowOpen).toHaveBeenCalledWith(
'https://chrome.google.com/webstore/detail/dify-chatbot/ceehdapohffmjmkdcifjofadiaoeggaf',
'_blank',
'noopener,noreferrer',
)
})
})

View File

@ -1,217 +0,0 @@
import type { ReactNode } from 'react'
import type { ModalContextState } from '@/context/modal-context'
import type { ProviderContextState } from '@/context/provider-context'
import type { AppDetailResponse } from '@/models/app'
import type { AppSSO } from '@/types/app'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { Plan } from '@/app/components/billing/type'
import { baseProviderContextValue } from '@/context/provider-context'
import { AppModeEnum } from '@/types/app'
import SettingsModal from './index'
vi.mock('react-i18next', async () => {
const actual = await vi.importActual<typeof import('react-i18next')>('react-i18next')
return {
...actual,
useTranslation: () => ({
t: (key: string, options?: Record<string, unknown>) => {
if (options?.returnObjects)
return [`${key}-feature-1`, `${key}-feature-2`]
if (options)
return `${key}:${JSON.stringify(options)}`
return key
},
i18n: {
language: 'en',
changeLanguage: vi.fn(),
},
}),
Trans: ({ children }: { children?: ReactNode }) => <>{children}</>,
}
})
const mockNotify = vi.fn()
const mockOnClose = vi.fn()
const mockOnSave = vi.fn()
const mockSetShowPricingModal = vi.fn()
const mockSetShowAccountSettingModal = vi.fn()
const mockUseProviderContext = vi.fn<() => ProviderContextState>()
const buildModalContext = (): ModalContextState => ({
setShowAccountSettingModal: mockSetShowAccountSettingModal,
setShowApiBasedExtensionModal: vi.fn(),
setShowModerationSettingModal: vi.fn(),
setShowExternalDataToolModal: vi.fn(),
setShowPricingModal: mockSetShowPricingModal,
setShowAnnotationFullModal: vi.fn(),
setShowModelModal: vi.fn(),
setShowExternalKnowledgeAPIModal: vi.fn(),
setShowModelLoadBalancingModal: vi.fn(),
setShowOpeningModal: vi.fn(),
setShowUpdatePluginModal: vi.fn(),
setShowEducationExpireNoticeModal: vi.fn(),
setShowTriggerEventsLimitModal: vi.fn(),
})
vi.mock('@/context/modal-context', () => ({
useModalContext: () => buildModalContext(),
}))
vi.mock('@/app/components/base/toast', async () => {
const actual = await vi.importActual<typeof import('@/app/components/base/toast')>('@/app/components/base/toast')
return {
...actual,
useToastContext: () => ({
notify: mockNotify,
close: vi.fn(),
}),
}
})
vi.mock('@/context/i18n', async () => {
const actual = await vi.importActual<typeof import('@/context/i18n')>('@/context/i18n')
return {
...actual,
useDocLink: () => (path?: string) => `https://docs.example.com${path ?? ''}`,
}
})
vi.mock('@/context/provider-context', async () => {
const actual = await vi.importActual<typeof import('@/context/provider-context')>('@/context/provider-context')
return {
...actual,
useProviderContext: () => mockUseProviderContext(),
}
})
const mockAppInfo = {
site: {
title: 'Test App',
icon_type: 'emoji',
icon: '😀',
icon_background: '#ABCDEF',
icon_url: 'https://example.com/icon.png',
description: 'A description',
chat_color_theme: '#123456',
chat_color_theme_inverted: true,
copyright: '© Dify',
privacy_policy: '',
custom_disclaimer: 'Disclaimer',
default_language: 'en-US',
show_workflow_steps: true,
use_icon_as_answer_icon: true,
},
mode: AppModeEnum.ADVANCED_CHAT,
enable_sso: false,
} as unknown as AppDetailResponse & Partial<AppSSO>
const renderSettingsModal = () => render(
<SettingsModal
isChat
isShow
appInfo={mockAppInfo}
onClose={mockOnClose}
onSave={mockOnSave}
/>,
)
describe('SettingsModal', () => {
beforeEach(() => {
mockNotify.mockClear()
mockOnClose.mockClear()
mockOnSave.mockClear()
mockSetShowPricingModal.mockClear()
mockSetShowAccountSettingModal.mockClear()
mockUseProviderContext.mockReturnValue({
...baseProviderContextValue,
enableBilling: true,
plan: {
...baseProviderContextValue.plan,
type: Plan.sandbox,
},
webappCopyrightEnabled: true,
})
})
it('should render the modal and expose the expanded settings section', async () => {
renderSettingsModal()
expect(screen.getByText('appOverview.overview.appInfo.settings.title')).toBeInTheDocument()
const showMoreEntry = screen.getByText('appOverview.overview.appInfo.settings.more.entry')
fireEvent.click(showMoreEntry)
await waitFor(() => {
expect(screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.copyRightPlaceholder')).toBeInTheDocument()
expect(screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.privacyPolicyPlaceholder')).toBeInTheDocument()
})
})
it('should notify the user when the name is empty', async () => {
renderSettingsModal()
const nameInput = screen.getByPlaceholderText('app.appNamePlaceholder')
fireEvent.change(nameInput, { target: { value: '' } })
fireEvent.click(screen.getByText('common.operation.save'))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ message: 'app.newApp.nameNotEmpty' }))
})
expect(mockOnSave).not.toHaveBeenCalled()
})
it('should validate the theme color and show an error when the hex is invalid', async () => {
renderSettingsModal()
const colorInput = screen.getByPlaceholderText('E.g #A020F0')
fireEvent.change(colorInput, { target: { value: 'not-a-hex' } })
fireEvent.click(screen.getByText('common.operation.save'))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
message: 'appOverview.overview.appInfo.settings.invalidHexMessage',
}))
})
expect(mockOnSave).not.toHaveBeenCalled()
})
it('should validate the privacy policy URL when advanced settings are open', async () => {
renderSettingsModal()
fireEvent.click(screen.getByText('appOverview.overview.appInfo.settings.more.entry'))
const privacyInput = screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.privacyPolicyPlaceholder')
// eslint-disable-next-line sonarjs/no-clear-text-protocols
fireEvent.change(privacyInput, { target: { value: 'ftp://invalid-url' } })
fireEvent.click(screen.getByText('common.operation.save'))
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
message: 'appOverview.overview.appInfo.settings.invalidPrivacyPolicy',
}))
})
expect(mockOnSave).not.toHaveBeenCalled()
})
it('should save valid settings and close the modal', async () => {
mockOnSave.mockResolvedValueOnce(undefined)
renderSettingsModal()
fireEvent.click(screen.getByText('common.operation.save'))
await waitFor(() => expect(mockOnSave).toHaveBeenCalled())
expect(mockOnSave).toHaveBeenCalledWith(expect.objectContaining({
title: mockAppInfo.site.title,
description: mockAppInfo.site.description,
default_language: mockAppInfo.site.default_language,
chat_color_theme: mockAppInfo.site.chat_color_theme,
chat_color_theme_inverted: mockAppInfo.site.chat_color_theme_inverted,
prompt_public: false,
copyright: mockAppInfo.site.copyright,
privacy_policy: mockAppInfo.site.privacy_policy,
custom_disclaimer: mockAppInfo.site.custom_disclaimer,
icon_type: 'emoji',
icon: mockAppInfo.site.icon,
icon_background: mockAppInfo.site.icon_background,
show_workflow_steps: mockAppInfo.site.show_workflow_steps,
use_icon_as_answer_icon: mockAppInfo.site.use_icon_as_answer_icon,
enable_sso: mockAppInfo.enable_sso,
}))
expect(mockOnClose).toHaveBeenCalled()
})
})

View File

@ -1,67 +0,0 @@
import type { ISavedItemsProps } from './index'
import { fireEvent, render, screen } from '@testing-library/react'
import copy from 'copy-to-clipboard'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Toast from '@/app/components/base/toast'
import SavedItems from './index'
vi.mock('copy-to-clipboard', () => ({
__esModule: true,
default: vi.fn(),
}))
vi.mock('next/navigation', () => ({
useParams: () => ({}),
usePathname: () => '/',
}))
const mockCopy = vi.mocked(copy)
const toastNotifySpy = vi.spyOn(Toast, 'notify')
const baseProps: ISavedItemsProps = {
list: [
{ id: '1', answer: 'hello world' },
],
isShowTextToSpeech: true,
onRemove: vi.fn(),
onStartCreateContent: vi.fn(),
}
describe('SavedItems', () => {
beforeEach(() => {
vi.clearAllMocks()
toastNotifySpy.mockClear()
})
it('renders saved answers with metadata and controls', () => {
const { container } = render(<SavedItems {...baseProps} />)
const markdownElement = container.querySelector('.markdown-body')
expect(markdownElement).toBeInTheDocument()
expect(screen.getByText('11 common.unit.char')).toBeInTheDocument()
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
const actionButtons = actionArea?.querySelectorAll('button') ?? []
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
})
it('copies content and notifies, and triggers remove callback', () => {
const handleRemove = vi.fn()
const { container } = render(<SavedItems {...baseProps} onRemove={handleRemove} />)
const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
const actionButtons = actionArea?.querySelectorAll('button') ?? []
expect(actionButtons.length).toBeGreaterThanOrEqual(3)
const copyButton = actionButtons[1]
const deleteButton = actionButtons[2]
fireEvent.click(copyButton)
expect(mockCopy).toHaveBeenCalledWith('hello world')
expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'common.actionMsg.copySuccessfully' })
fireEvent.click(deleteButton)
expect(handleRemove).toHaveBeenCalledWith('1')
})
})

View File

@ -1,22 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import NoData from './index'
describe('NoData', () => {
it('renders title/description and calls callback when button clicked', () => {
const handleStart = vi.fn()
render(<NoData onStartCreateContent={handleStart} />)
const title = screen.getByText('share.generation.savedNoData.title')
const description = screen.getByText('share.generation.savedNoData.description')
const button = screen.getByRole('button', { name: 'share.generation.savedNoData.startCreateContent' })
expect(title).toBeInTheDocument()
expect(description).toBeInTheDocument()
expect(button).toBeInTheDocument()
fireEvent.click(button)
expect(handleStart).toHaveBeenCalledTimes(1)
})
})

View File

@ -0,0 +1,363 @@
/**
* Test suite for useAppsQueryState hook
*
* This hook manages app filtering state through URL search parameters, enabling:
* - Bookmarkable filter states (users can share URLs with specific filters active)
* - Browser history integration (back/forward buttons work with filters)
* - Multiple filter types: tagIDs, keywords, isCreatedByMe
*
* The hook syncs local filter state with URL search parameters, making filter
* navigation persistent and shareable across sessions.
*/
import { act, renderHook } from '@testing-library/react'
// Import the hook after mocks are set up
import useAppsQueryState from './use-apps-query-state'
// Mock Next.js navigation hooks
const mockPush = vi.fn()
const mockPathname = '/apps'
let mockSearchParams = new URLSearchParams()
vi.mock('next/navigation', () => ({
usePathname: vi.fn(() => mockPathname),
useRouter: vi.fn(() => ({
push: mockPush,
})),
useSearchParams: vi.fn(() => mockSearchParams),
}))
describe('useAppsQueryState', () => {
beforeEach(() => {
vi.clearAllMocks()
mockSearchParams = new URLSearchParams()
})
describe('Basic functionality', () => {
it('should return query object and setQuery function', () => {
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query).toBeDefined()
expect(typeof result.current.setQuery).toBe('function')
})
it('should initialize with empty query when no search params exist', () => {
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.tagIDs).toBeUndefined()
expect(result.current.query.keywords).toBeUndefined()
expect(result.current.query.isCreatedByMe).toBe(false)
})
})
describe('Parsing search params', () => {
it('should parse tagIDs from URL', () => {
mockSearchParams.set('tagIDs', 'tag1;tag2;tag3')
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2', 'tag3'])
})
it('should parse single tagID from URL', () => {
mockSearchParams.set('tagIDs', 'single-tag')
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.tagIDs).toEqual(['single-tag'])
})
it('should parse keywords from URL', () => {
mockSearchParams.set('keywords', 'search term')
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.keywords).toBe('search term')
})
it('should parse isCreatedByMe as true from URL', () => {
mockSearchParams.set('isCreatedByMe', 'true')
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.isCreatedByMe).toBe(true)
})
it('should parse isCreatedByMe as false for other values', () => {
mockSearchParams.set('isCreatedByMe', 'false')
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.isCreatedByMe).toBe(false)
})
it('should parse all params together', () => {
mockSearchParams.set('tagIDs', 'tag1;tag2')
mockSearchParams.set('keywords', 'test')
mockSearchParams.set('isCreatedByMe', 'true')
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
expect(result.current.query.keywords).toBe('test')
expect(result.current.query.isCreatedByMe).toBe(true)
})
})
describe('Updating query state', () => {
it('should update keywords via setQuery', () => {
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ keywords: 'new search' })
})
expect(result.current.query.keywords).toBe('new search')
})
it('should update tagIDs via setQuery', () => {
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ tagIDs: ['tag1', 'tag2'] })
})
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
})
it('should update isCreatedByMe via setQuery', () => {
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ isCreatedByMe: true })
})
expect(result.current.query.isCreatedByMe).toBe(true)
})
it('should support partial updates via callback', () => {
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ keywords: 'initial' })
})
act(() => {
result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true }))
})
expect(result.current.query.keywords).toBe('initial')
expect(result.current.query.isCreatedByMe).toBe(true)
})
})
describe('URL synchronization', () => {
it('should sync keywords to URL', async () => {
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ keywords: 'search' })
})
// Wait for useEffect to run
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
expect(mockPush).toHaveBeenCalledWith(
expect.stringContaining('keywords=search'),
{ scroll: false },
)
})
it('should sync tagIDs to URL with semicolon separator', async () => {
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ tagIDs: ['tag1', 'tag2'] })
})
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
expect(mockPush).toHaveBeenCalledWith(
expect.stringContaining('tagIDs=tag1%3Btag2'),
{ scroll: false },
)
})
it('should sync isCreatedByMe to URL', async () => {
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ isCreatedByMe: true })
})
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
expect(mockPush).toHaveBeenCalledWith(
expect.stringContaining('isCreatedByMe=true'),
{ scroll: false },
)
})
it('should remove keywords from URL when empty', async () => {
mockSearchParams.set('keywords', 'existing')
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ keywords: '' })
})
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
// Should be called without keywords param
expect(mockPush).toHaveBeenCalled()
})
it('should remove tagIDs from URL when empty array', async () => {
mockSearchParams.set('tagIDs', 'tag1;tag2')
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ tagIDs: [] })
})
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
expect(mockPush).toHaveBeenCalled()
})
it('should remove isCreatedByMe from URL when false', async () => {
mockSearchParams.set('isCreatedByMe', 'true')
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ isCreatedByMe: false })
})
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
expect(mockPush).toHaveBeenCalled()
})
})
describe('Edge cases', () => {
it('should handle empty tagIDs string in URL', () => {
// NOTE: This test documents current behavior where ''.split(';') returns ['']
// This could potentially cause filtering issues as it's treated as a tag with empty name
// rather than absence of tags. Consider updating parseParams if this is problematic.
mockSearchParams.set('tagIDs', '')
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.tagIDs).toEqual([''])
})
it('should handle empty keywords', () => {
mockSearchParams.set('keywords', '')
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.keywords).toBeUndefined()
})
it('should handle undefined tagIDs', () => {
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ tagIDs: undefined })
})
expect(result.current.query.tagIDs).toBeUndefined()
})
it('should handle special characters in keywords', () => {
// Use URLSearchParams constructor to properly simulate URL decoding behavior
// URLSearchParams.get() decodes URL-encoded characters
mockSearchParams = new URLSearchParams('keywords=test%20with%20spaces')
const { result } = renderHook(() => useAppsQueryState())
expect(result.current.query.keywords).toBe('test with spaces')
})
})
describe('Memoization', () => {
it('should return memoized object reference when query unchanged', () => {
const { result, rerender } = renderHook(() => useAppsQueryState())
const firstResult = result.current
rerender()
const secondResult = result.current
expect(firstResult.query).toBe(secondResult.query)
})
it('should return new object reference when query changes', () => {
const { result } = renderHook(() => useAppsQueryState())
const firstQuery = result.current.query
act(() => {
result.current.setQuery({ keywords: 'changed' })
})
expect(result.current.query).not.toBe(firstQuery)
})
})
describe('Integration scenarios', () => {
it('should handle sequential updates', async () => {
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({ keywords: 'first' })
})
act(() => {
result.current.setQuery(prev => ({ ...prev, tagIDs: ['tag1'] }))
})
act(() => {
result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true }))
})
expect(result.current.query.keywords).toBe('first')
expect(result.current.query.tagIDs).toEqual(['tag1'])
expect(result.current.query.isCreatedByMe).toBe(true)
})
it('should clear all filters', () => {
mockSearchParams.set('tagIDs', 'tag1;tag2')
mockSearchParams.set('keywords', 'search')
mockSearchParams.set('isCreatedByMe', 'true')
const { result } = renderHook(() => useAppsQueryState())
act(() => {
result.current.setQuery({
tagIDs: undefined,
keywords: undefined,
isCreatedByMe: false,
})
})
expect(result.current.query.tagIDs).toBeUndefined()
expect(result.current.query.keywords).toBeUndefined()
expect(result.current.query.isCreatedByMe).toBe(false)
})
})
})

Some files were not shown because too many files have changed in this diff Show More