Mirror of https://github.com/langgenius/dify.git (synced 2026-02-04 02:37:53 +08:00)

Compare commits: 18 commits, fix/loop-i... to release/e-...

| SHA1 |
|---|
| c0c64c75d6 |
| 9b0984eab7 |
| 83a943d8c4 |
| c21f7c48fb |
| 8371a26cdf |
| a6e43f0fa0 |
| 30f9199fba |
| b47afdd314 |
| fc81605ae8 |
| 248871fca1 |
| 4e0d3c224f |
| c9858f851f |
| c17052b8b4 |
| 70571b53ad |
| 44ef3cc27d |
| 676063890c |
| 901cc64ac9 |
| 894a3c03a2 |
@@ -1,8 +0,0 @@
{
  "enabledPlugins": {
    "feature-dev@claude-plugins-official": true,
    "context7@claude-plugins-official": true,
    "typescript-lsp@claude-plugins-official": true,
    "pyright-lsp@claude-plugins-official": true
  }
}

.claude/settings.json.example (new file, 19 lines)
@@ -0,0 +1,19 @@
{
  "permissions": {
    "allow": [],
    "deny": []
  },
  "env": {
    "__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
    "GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
  },
  "enabledMcpjsonServers": [
    "context7",
    "sequential-thinking",
    "github",
    "fetch",
    "playwright",
    "ide"
  ],
  "enableAllProjectMcpServers": true
}
@@ -1,483 +0,0 @@
---
name: component-refactoring
description: Refactor high-complexity React components in the Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing. Avoid for simple or well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
---

# Dify Component Refactoring Skill

Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.

> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
## Quick Reference

### Commands (run from `web/`)

Use paths relative to `web/` (e.g., `app/components/...`).
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.

```bash
cd web

# Generate refactoring prompt
pnpm refactor-component <path>

# Output refactoring analysis as JSON
pnpm refactor-component <path> --json

# Generate testing prompt (after refactoring)
pnpm analyze-component <path>

# Output testing analysis as JSON
pnpm analyze-component <path> --json
```

### Complexity Analysis

```bash
# Analyze component complexity
pnpm analyze-component <path> --json

# Key metrics to check:
# - complexity: normalized score 0-100 (target < 50)
# - maxComplexity: highest single function complexity
# - lineCount: total lines (target < 300)
```

### Complexity Score Interpretation

| Score | Level | Action |
|-------|-------|--------|
| 0-25 | 🟢 Simple | Ready for testing |
| 26-50 | 🟡 Medium | Consider minor refactoring |
| 51-75 | 🟠 Complex | **Refactor before testing** |
| 76-100 | 🔴 Very Complex | **Must refactor** |
## Core Refactoring Patterns
|
||||
|
||||
### Pattern 1: Extract Custom Hooks
|
||||
|
||||
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
|
||||
|
||||
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Complex state logic in component
|
||||
const Configuration: FC = () => {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
|
||||
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||
|
||||
// 50+ lines of state management logic...
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
|
||||
// ✅ After: Extract to custom hook
|
||||
// hooks/use-model-config.ts
|
||||
export const useModelConfig = (appId: string) => {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||
|
||||
// Related state management logic here
|
||||
|
||||
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const Configuration: FC = () => {
|
||||
const { modelConfig, setModelConfig } = useModelConfig(appId)
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Dify Examples**:
|
||||
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
|
||||
- `web/app/components/app/configuration/debug/hooks.tsx`
|
||||
- `web/app/components/workflow/hooks/use-workflow.ts`
|
||||
|
||||
### Pattern 2: Extract Sub-Components
|
||||
|
||||
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
|
||||
|
||||
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Monolithic JSX with multiple sections
|
||||
const AppInfo = () => {
|
||||
return (
|
||||
<div>
|
||||
{/* 100 lines of header UI */}
|
||||
{/* 100 lines of operations UI */}
|
||||
{/* 100 lines of modals */}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Split into focused components
|
||||
// app-info/
|
||||
// ├── index.tsx (orchestration only)
|
||||
// ├── app-header.tsx (header UI)
|
||||
// ├── app-operations.tsx (operations UI)
|
||||
// └── app-modals.tsx (modal management)
|
||||
|
||||
const AppInfo = () => {
|
||||
const { showModal, setShowModal } = useAppInfoModals()
|
||||
|
||||
return (
|
||||
<div>
|
||||
<AppHeader appDetail={appDetail} />
|
||||
<AppOperations onAction={handleAction} />
|
||||
<AppModals show={showModal} onClose={() => setShowModal(null)} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Dify Examples**:
|
||||
- `web/app/components/app/configuration/` directory structure
|
||||
- `web/app/components/workflow/nodes/` per-node organization
|
||||
|
||||
### Pattern 3: Simplify Conditional Logic
|
||||
|
||||
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Deeply nested conditionals
|
||||
const Template = useMemo(() => {
|
||||
if (appDetail?.mode === AppModeEnum.CHAT) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateChatZh />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateChatJa />
|
||||
default:
|
||||
return <TemplateChatEn />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
|
||||
// Another 15 lines...
|
||||
}
|
||||
// More conditions...
|
||||
}, [appDetail, locale])
|
||||
|
||||
// ✅ After: Use lookup tables + early returns
|
||||
const TEMPLATE_MAP = {
|
||||
[AppModeEnum.CHAT]: {
|
||||
[LanguagesSupported[1]]: TemplateChatZh,
|
||||
[LanguagesSupported[7]]: TemplateChatJa,
|
||||
default: TemplateChatEn,
|
||||
},
|
||||
[AppModeEnum.ADVANCED_CHAT]: {
|
||||
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
|
||||
// ...
|
||||
},
|
||||
}
|
||||
|
||||
const Template = useMemo(() => {
|
||||
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
|
||||
if (!modeTemplates) return null
|
||||
|
||||
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
|
||||
return <TemplateComponent appDetail={appDetail} />
|
||||
}, [appDetail, locale])
|
||||
```
|
||||
|
||||
### Pattern 4: Extract API/Data Logic
|
||||
|
||||
**When**: Component directly handles API calls, data transformation, or complex async operations.
|
||||
|
||||
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: API logic in component
|
||||
const MCPServiceCard = () => {
|
||||
const [basicAppConfig, setBasicAppConfig] = useState({})
|
||||
|
||||
useEffect(() => {
|
||||
if (isBasicApp && appId) {
|
||||
(async () => {
|
||||
const res = await fetchAppDetail({ url: '/apps', id: appId })
|
||||
setBasicAppConfig(res?.model_config || {})
|
||||
})()
|
||||
}
|
||||
}, [appId, isBasicApp])
|
||||
|
||||
// More API-related logic...
|
||||
}
|
||||
|
||||
// ✅ After: Extract to data hook using React Query
|
||||
// use-app-config.ts
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
|
||||
return useQuery({
|
||||
enabled: isBasicApp && !!appId,
|
||||
queryKey: [NAME_SPACE, 'detail', appId],
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || {},
|
||||
})
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const MCPServiceCard = () => {
|
||||
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
|
||||
// UI only
|
||||
}
|
||||
```
|
||||
|
||||
**React Query Best Practices in Dify**:

- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx` (see the sketch below)
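A minimal sketch of that invalidation convention, assuming only `useQueryClient` from `@tanstack/react-query`; the name `useInvalidAppConfig` is illustrative (the real service hooks build the same idea on `useInvalid` from `@/service/use-base`):

```typescript
import { useCallback } from 'react'
import { useQueryClient } from '@tanstack/react-query'

const NAME_SPACE = 'appConfig'

// Invalidate every query under this namespace so consumers refetch.
export const useInvalidAppConfig = () => {
  const queryClient = useQueryClient()
  return useCallback(
    () => queryClient.invalidateQueries({ queryKey: [NAME_SPACE] }),
    [queryClient],
  )
}
```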
|
||||
|
||||
**Dify Examples**:
|
||||
- `web/service/use-workflow.ts`
|
||||
- `web/service/use-common.ts`
|
||||
- `web/service/knowledge/use-dataset.ts`
|
||||
- `web/service/knowledge/use-document.ts`
|
||||
|
||||
### Pattern 5: Extract Modal/Dialog Management
|
||||
|
||||
**When**: Component manages multiple modals with complex open/close states.
|
||||
|
||||
**Dify Convention**: Modals should be extracted with their state management.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Multiple modal states in component
|
||||
const AppInfo = () => {
|
||||
const [showEditModal, setShowEditModal] = useState(false)
|
||||
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
|
||||
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
|
||||
const [showSwitchModal, setShowSwitchModal] = useState(false)
|
||||
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
|
||||
// 5+ more modal states...
|
||||
}
|
||||
|
||||
// ✅ After: Extract to modal management hook
|
||||
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
|
||||
|
||||
const useAppInfoModals = () => {
|
||||
const [activeModal, setActiveModal] = useState<ModalType>(null)
|
||||
|
||||
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
|
||||
const closeModal = useCallback(() => setActiveModal(null), [])
|
||||
|
||||
return {
|
||||
activeModal,
|
||||
openModal,
|
||||
closeModal,
|
||||
isOpen: (type: ModalType) => activeModal === type,
|
||||
}
|
||||
}
|
||||
```
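A usage sketch for the hook above; `EditModal`, `DuplicateModal`, `Button`, and `appDetail` are stand-ins borrowed from the earlier examples rather than a specific Dify file:

```typescript
const AppInfo: FC<{ appDetail: AppDetail }> = ({ appDetail }) => {
  const { openModal, closeModal, isOpen } = useAppInfoModals()

  return (
    <div>
      <Button onClick={() => openModal('edit')}>Edit</Button>
      <Button onClick={() => openModal('duplicate')}>Duplicate</Button>

      {/* Each modal renders only while its type is active */}
      {isOpen('edit') && <EditModal appDetail={appDetail} onClose={closeModal} />}
      {isOpen('duplicate') && <DuplicateModal appDetail={appDetail} onClose={closeModal} />}
    </div>
  )
}
```

A single `activeModal` value also guarantees that only one modal can be open at a time, which removes a whole class of overlapping-modal bugs.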
|
||||
|
||||
### Pattern 6: Extract Form Logic
|
||||
|
||||
**When**: Complex form validation, submission handling, or field transformation.
|
||||
|
||||
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
|
||||
|
||||
```typescript
|
||||
// ✅ Use existing form infrastructure
|
||||
import { useAppForm } from '@/app/components/base/form'
|
||||
|
||||
const ConfigForm = () => {
|
||||
const form = useAppForm({
|
||||
defaultValues: { name: '', description: '' },
|
||||
onSubmit: handleSubmit,
|
||||
})
|
||||
|
||||
return <form.Provider>...</form.Provider>
|
||||
}
|
||||
```
|
||||
|
||||
## Dify-Specific Refactoring Guidelines
|
||||
|
||||
### 1. Context Provider Extraction
|
||||
|
||||
**When**: Component provides complex context values with multiple states.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Large context value object
|
||||
const value = {
|
||||
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
|
||||
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
|
||||
// 50+ more properties...
|
||||
}
|
||||
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
|
||||
|
||||
// ✅ After: Split into domain-specific contexts
|
||||
<ModelConfigProvider value={modelConfigValue}>
|
||||
<DatasetConfigProvider value={datasetConfigValue}>
|
||||
<UIConfigProvider value={uiConfigValue}>
|
||||
{children}
|
||||
</UIConfigProvider>
|
||||
</DatasetConfigProvider>
|
||||
</ModelConfigProvider>
|
||||
```
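One way to define such a domain-scoped provider is with plain React context. A sketch assuming the `ModelConfig` type used elsewhere in this guide; the provider and hook names mirror the snippet above, while Dify's real contexts live under `web/context/`:

```typescript
import { createContext, useContext } from 'react'
import type { FC, ReactNode } from 'react'
import type { ModelConfig } from '@/models/debug'

// Narrow, domain-scoped value instead of one 50-property object.
type ModelConfigValue = {
  modelConfig: ModelConfig
  setModelConfig: (config: ModelConfig) => void
}

const ModelConfigContext = createContext<ModelConfigValue | null>(null)

export const ModelConfigProvider: FC<{ value: ModelConfigValue, children: ReactNode }> = ({ value, children }) => (
  <ModelConfigContext.Provider value={value}>{children}</ModelConfigContext.Provider>
)

// Consumers read only the slice they need, so unrelated config updates don't re-render them.
export const useModelConfigContext = () => {
  const ctx = useContext(ModelConfigContext)
  if (!ctx)
    throw new Error('useModelConfigContext must be used inside ModelConfigProvider')
  return ctx
}
```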
|
||||
|
||||
**Dify Reference**: `web/context/` directory structure
|
||||
|
||||
### 2. Workflow Node Components
|
||||
|
||||
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
|
||||
|
||||
**Conventions**:
|
||||
- Keep node logic in `use-interactions.ts`
|
||||
- Extract panel UI to separate files
|
||||
- Use `_base` components for common patterns
|
||||
|
||||
```
|
||||
nodes/<node-type>/
|
||||
├── index.tsx # Node registration
|
||||
├── node.tsx # Node visual component
|
||||
├── panel.tsx # Configuration panel
|
||||
├── use-interactions.ts # Node-specific hooks
|
||||
└── types.ts # Type definitions
|
||||
```
|
||||
|
||||
### 3. Configuration Components
|
||||
|
||||
**When**: Refactoring app configuration components.
|
||||
|
||||
**Conventions**:
|
||||
- Separate config sections into subdirectories
|
||||
- Use existing patterns from `web/app/components/app/configuration/`
|
||||
- Keep feature toggles in dedicated components
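An illustrative shape for such a dedicated feature-toggle component (plain React; it makes no assumptions about Dify's base components):

```typescript
type FeatureToggleProps = {
  title: string
  enabled: boolean
  onChange: (enabled: boolean) => void
  children?: ReactNode
}

// One feature toggle per component keeps the page-level config screen declarative.
const FeatureToggle: FC<FeatureToggleProps> = ({ title, enabled, onChange, children }) => (
  <div className="feature-panel">
    <div className="flex items-center justify-between">
      <span>{title}</span>
      <input
        type="checkbox"
        checked={enabled}
        onChange={e => onChange(e.target.checked)}
      />
    </div>
    {enabled && children}
  </div>
)
```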
|
||||
|
||||
### 4. Tool/Plugin Components
|
||||
|
||||
**When**: Refactoring tool-related components (`web/app/components/tools/`).
|
||||
|
||||
**Conventions**:
|
||||
- Follow existing modal patterns
|
||||
- Use service hooks from `web/service/use-tools.ts`
|
||||
- Keep provider-specific logic isolated
|
||||
|
||||
## Refactoring Workflow
|
||||
|
||||
### Step 1: Generate Refactoring Prompt
|
||||
|
||||
```bash
|
||||
pnpm refactor-component <path>
|
||||
```
|
||||
|
||||
This command will:
|
||||
- Analyze component complexity and features
|
||||
- Identify specific refactoring actions needed
|
||||
- Generate a prompt for AI assistant (auto-copied to clipboard on macOS)
|
||||
- Provide detailed requirements based on detected patterns
|
||||
|
||||
### Step 2: Analyze Details

```bash
pnpm analyze-component <path> --json
```

Identify:

- Total complexity score
- Max function complexity
- Line count
- Features detected (state, effects, API, etc.)
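The JSON output is a flat object; a hypothetical sketch of its shape, limited to the fields this guide actually reads (the real tool may expose more):

```typescript
// Assumed shape of `pnpm analyze-component <path> --json`, for orientation only.
type ComponentAnalysis = {
  complexity: number    // normalized 0-100 score
  maxComplexity: number // highest single-function cognitive complexity
  lineCount: number
  hasState: boolean     // useState/useReducer present
  hasEffects: boolean   // useEffect/useLayoutEffect present
  hasAPI: boolean       // fetches data or calls services
  hasEvents: boolean    // event handlers defined
}
```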
### Step 3: Plan

Create a refactoring plan based on detected features:

| Detected Feature | Refactoring Action |
|------------------|-------------------|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
| `hasAPI: true` | Extract data/service hook |
| `hasEvents: true` (many) | Extract event handlers |
| `lineCount > 300` | Split into sub-components |
| `maxComplexity > 50` | Simplify conditional logic |
### Step 4: Execute Incrementally
|
||||
|
||||
1. **Extract one piece at a time**
|
||||
2. **Run lint, type-check, and tests after each extraction**
|
||||
3. **Verify functionality before next step**
|
||||
|
||||
```
|
||||
For each extraction:
|
||||
┌────────────────────────────────────────┐
|
||||
│ 1. Extract code │
|
||||
│ 2. Run: pnpm lint:fix │
|
||||
│ 3. Run: pnpm type-check:tsgo │
|
||||
│ 4. Run: pnpm test │
|
||||
│ 5. Test functionality manually │
|
||||
│ 6. PASS? → Next extraction │
|
||||
│ FAIL? → Fix before continuing │
|
||||
└────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Step 5: Verify
|
||||
|
||||
After refactoring:
|
||||
|
||||
```bash
|
||||
# Re-run refactor command to verify improvements
|
||||
pnpm refactor-component <path>
|
||||
|
||||
# If complexity < 25 and lines < 200, you'll see:
|
||||
# ✅ COMPONENT IS WELL-STRUCTURED
|
||||
|
||||
# For detailed metrics:
|
||||
pnpm analyze-component <path> --json
|
||||
|
||||
# Target metrics:
|
||||
# - complexity < 50
|
||||
# - lineCount < 300
|
||||
# - maxComplexity < 30
|
||||
```
|
||||
|
||||
## Common Mistakes to Avoid
|
||||
|
||||
### ❌ Over-Engineering
|
||||
|
||||
```typescript
|
||||
// ❌ Too many tiny hooks
|
||||
const useButtonText = () => useState('Click')
|
||||
const useButtonDisabled = () => useState(false)
|
||||
const useButtonLoading = () => useState(false)
|
||||
|
||||
// ✅ Cohesive hook with related state
|
||||
const useButtonState = () => {
|
||||
const [text, setText] = useState('Click')
|
||||
const [disabled, setDisabled] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
return { text, setText, disabled, setDisabled, loading, setLoading }
|
||||
}
|
||||
```
|
||||
|
||||
### ❌ Breaking Existing Patterns
|
||||
|
||||
- Follow existing directory structures
|
||||
- Maintain naming conventions
|
||||
- Preserve export patterns for compatibility
|
||||
|
||||
### ❌ Premature Abstraction
|
||||
|
||||
- Only extract when there's clear complexity benefit
|
||||
- Don't create abstractions for single-use code
|
||||
- Keep refactored code in the same domain area
|
||||
|
||||
## References

### Dify Codebase Examples

- **Hook extraction**: `web/app/components/app/configuration/hooks/`
- **Component splitting**: `web/app/components/app/configuration/`
- **Service hooks**: `web/service/use-*.ts`
- **Workflow patterns**: `web/app/components/workflow/hooks/`
- **Form patterns**: `web/app/components/base/form/`

### Related Skills

- `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification
@@ -1,493 +0,0 @@
# Complexity Reduction Patterns

This document provides patterns for reducing cognitive complexity in Dify React components.

## Understanding Complexity

### SonarJS Cognitive Complexity

The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:

- **Total Complexity**: Sum of all functions' complexity in the file
- **Max Complexity**: Highest single function complexity

### What Increases Complexity

| Pattern | Complexity Impact |
|---------|-------------------|
| `if/else` | +1 per branch |
| Nested conditions | +1 per nesting level |
| `switch/case` | +1 per case |
| `for/while/do` | +1 per loop |
| `&&`/`||` chains | +1 per operator |
| Nested callbacks | +1 per nesting level |
| `try/catch` | +1 per catch |
| Ternary expressions | +1 per nesting |
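To see how those increments stack, here is a small annotated example (illustrative counts, not tool output, using a hypothetical `Item` shape similar to the one in Pattern 5 below):

```typescript
const findNestedId = (items: Item[], isAdmin: boolean) => {
  for (const item of items) {              // +1 (loop)
    if (item.isValid) {                    // +1 (if) +1 (nested in loop)
      if (isAdmin && item.children.length) // +1 (if) +2 (nesting) +1 (&&)
        return item.id
    }
  }
  return null                              // roughly 7 for one small function
}
```

Flattening the nesting with a lookup table or early returns (Patterns 1 and 2 below) removes most of the nesting increments.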
## Pattern 1: Replace Conditionals with Lookup Tables
|
||||
|
||||
**Before** (complexity: ~15):
|
||||
|
||||
```typescript
|
||||
const Template = useMemo(() => {
|
||||
if (appDetail?.mode === AppModeEnum.CHAT) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateChatZh appDetail={appDetail} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateChatJa appDetail={appDetail} />
|
||||
default:
|
||||
return <TemplateChatEn appDetail={appDetail} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateAdvancedChatZh appDetail={appDetail} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateAdvancedChatJa appDetail={appDetail} />
|
||||
default:
|
||||
return <TemplateAdvancedChatEn appDetail={appDetail} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
|
||||
// Similar pattern...
|
||||
}
|
||||
return null
|
||||
}, [appDetail, locale])
|
||||
```
|
||||
|
||||
**After** (complexity: ~3):
|
||||
|
||||
```typescript
|
||||
// Define lookup table outside component
|
||||
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
|
||||
[AppModeEnum.CHAT]: {
|
||||
[LanguagesSupported[1]]: TemplateChatZh,
|
||||
[LanguagesSupported[7]]: TemplateChatJa,
|
||||
default: TemplateChatEn,
|
||||
},
|
||||
[AppModeEnum.ADVANCED_CHAT]: {
|
||||
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
|
||||
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
|
||||
default: TemplateAdvancedChatEn,
|
||||
},
|
||||
[AppModeEnum.WORKFLOW]: {
|
||||
[LanguagesSupported[1]]: TemplateWorkflowZh,
|
||||
[LanguagesSupported[7]]: TemplateWorkflowJa,
|
||||
default: TemplateWorkflowEn,
|
||||
},
|
||||
// ...
|
||||
}
|
||||
|
||||
// Clean component logic
|
||||
const Template = useMemo(() => {
|
||||
if (!appDetail?.mode) return null
|
||||
|
||||
const templates = TEMPLATE_MAP[appDetail.mode]
|
||||
if (!templates) return null
|
||||
|
||||
const TemplateComponent = templates[locale] ?? templates.default
|
||||
return <TemplateComponent appDetail={appDetail} />
|
||||
}, [appDetail, locale])
|
||||
```
|
||||
|
||||
## Pattern 2: Use Early Returns
|
||||
|
||||
**Before** (complexity: ~10):
|
||||
|
||||
```typescript
|
||||
const handleSubmit = () => {
|
||||
if (isValid) {
|
||||
if (hasChanges) {
|
||||
if (isConnected) {
|
||||
submitData()
|
||||
} else {
|
||||
showConnectionError()
|
||||
}
|
||||
} else {
|
||||
showNoChangesMessage()
|
||||
}
|
||||
} else {
|
||||
showValidationError()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**After** (complexity: ~4):
|
||||
|
||||
```typescript
|
||||
const handleSubmit = () => {
|
||||
if (!isValid) {
|
||||
showValidationError()
|
||||
return
|
||||
}
|
||||
|
||||
if (!hasChanges) {
|
||||
showNoChangesMessage()
|
||||
return
|
||||
}
|
||||
|
||||
if (!isConnected) {
|
||||
showConnectionError()
|
||||
return
|
||||
}
|
||||
|
||||
submitData()
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern 3: Extract Complex Conditions
|
||||
|
||||
**Before** (complexity: high):
|
||||
|
||||
```typescript
|
||||
const canPublish = (() => {
|
||||
if (mode !== AppModeEnum.COMPLETION) {
|
||||
if (!isAdvancedMode)
|
||||
return true
|
||||
|
||||
if (modelModeType === ModelModeType.completion) {
|
||||
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
|
||||
return false
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
return !promptEmpty
|
||||
})()
|
||||
```
|
||||
|
||||
**After** (complexity: lower):
|
||||
|
||||
```typescript
|
||||
// Extract to named functions
|
||||
const canPublishInCompletionMode = () => !promptEmpty
|
||||
|
||||
const canPublishInChatMode = () => {
|
||||
if (!isAdvancedMode) return true
|
||||
if (modelModeType !== ModelModeType.completion) return true
|
||||
return hasSetBlockStatus.history && hasSetBlockStatus.query
|
||||
}
|
||||
|
||||
// Clean main logic
|
||||
const canPublish = mode === AppModeEnum.COMPLETION
|
||||
? canPublishInCompletionMode()
|
||||
: canPublishInChatMode()
|
||||
```
|
||||
|
||||
## Pattern 4: Replace Chained Ternaries
|
||||
|
||||
**Before** (complexity: ~5):
|
||||
|
||||
```typescript
|
||||
const statusText = serverActivated
|
||||
? t('status.running')
|
||||
: serverPublished
|
||||
? t('status.inactive')
|
||||
: appUnpublished
|
||||
? t('status.unpublished')
|
||||
: t('status.notConfigured')
|
||||
```
|
||||
|
||||
**After** (complexity: ~2):
|
||||
|
||||
```typescript
|
||||
const getStatusText = () => {
|
||||
if (serverActivated) return t('status.running')
|
||||
if (serverPublished) return t('status.inactive')
|
||||
if (appUnpublished) return t('status.unpublished')
|
||||
return t('status.notConfigured')
|
||||
}
|
||||
|
||||
const statusText = getStatusText()
|
||||
```
|
||||
|
||||
Or use lookup:
|
||||
|
||||
```typescript
|
||||
const STATUS_TEXT_MAP = {
|
||||
running: 'status.running',
|
||||
inactive: 'status.inactive',
|
||||
unpublished: 'status.unpublished',
|
||||
notConfigured: 'status.notConfigured',
|
||||
} as const
|
||||
|
||||
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
|
||||
if (serverActivated) return 'running'
|
||||
if (serverPublished) return 'inactive'
|
||||
if (appUnpublished) return 'unpublished'
|
||||
return 'notConfigured'
|
||||
}
|
||||
|
||||
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
|
||||
```
|
||||
|
||||
## Pattern 5: Flatten Nested Loops
|
||||
|
||||
**Before** (complexity: high):
|
||||
|
||||
```typescript
|
||||
const processData = (items: Item[]) => {
|
||||
const results: ProcessedItem[] = []
|
||||
|
||||
for (const item of items) {
|
||||
if (item.isValid) {
|
||||
for (const child of item.children) {
|
||||
if (child.isActive) {
|
||||
for (const prop of child.properties) {
|
||||
if (prop.value !== null) {
|
||||
results.push({
|
||||
itemId: item.id,
|
||||
childId: child.id,
|
||||
propValue: prop.value,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
```
|
||||
|
||||
**After** (complexity: lower):
|
||||
|
||||
```typescript
|
||||
// Use functional approach
|
||||
const processData = (items: Item[]) => {
|
||||
return items
|
||||
.filter(item => item.isValid)
|
||||
.flatMap(item =>
|
||||
item.children
|
||||
.filter(child => child.isActive)
|
||||
.flatMap(child =>
|
||||
child.properties
|
||||
.filter(prop => prop.value !== null)
|
||||
.map(prop => ({
|
||||
itemId: item.id,
|
||||
childId: child.id,
|
||||
propValue: prop.value,
|
||||
}))
|
||||
)
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern 6: Extract Event Handler Logic
|
||||
|
||||
**Before** (complexity: high in component):
|
||||
|
||||
```typescript
|
||||
const Component = () => {
|
||||
const handleSelect = (data: DataSet[]) => {
|
||||
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
|
||||
hideSelectDataSet()
|
||||
return
|
||||
}
|
||||
|
||||
formattingChangedDispatcher()
|
||||
let newDatasets = data
|
||||
if (data.find(item => !item.name)) {
|
||||
const newSelected = produce(data, (draft) => {
|
||||
data.forEach((item, index) => {
|
||||
if (!item.name) {
|
||||
const newItem = dataSets.find(i => i.id === item.id)
|
||||
if (newItem)
|
||||
draft[index] = newItem
|
||||
}
|
||||
})
|
||||
})
|
||||
setDataSets(newSelected)
|
||||
newDatasets = newSelected
|
||||
}
|
||||
else {
|
||||
setDataSets(data)
|
||||
}
|
||||
hideSelectDataSet()
|
||||
|
||||
// 40 more lines of logic...
|
||||
}
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
|
||||
**After** (complexity: lower):
|
||||
|
||||
```typescript
|
||||
// Extract to hook or utility
|
||||
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
|
||||
const normalizeSelection = (data: DataSet[]) => {
|
||||
const hasUnloadedItem = data.some(item => !item.name)
|
||||
if (!hasUnloadedItem) return data
|
||||
|
||||
return produce(data, (draft) => {
|
||||
data.forEach((item, index) => {
|
||||
if (!item.name) {
|
||||
const existing = dataSets.find(i => i.id === item.id)
|
||||
if (existing) draft[index] = existing
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
const hasSelectionChanged = (newData: DataSet[]) => {
|
||||
return !isEqual(
|
||||
newData.map(item => item.id),
|
||||
dataSets.map(item => item.id)
|
||||
)
|
||||
}
|
||||
|
||||
return { normalizeSelection, hasSelectionChanged }
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const Component = () => {
|
||||
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
|
||||
|
||||
const handleSelect = (data: DataSet[]) => {
|
||||
if (!hasSelectionChanged(data)) {
|
||||
hideSelectDataSet()
|
||||
return
|
||||
}
|
||||
|
||||
formattingChangedDispatcher()
|
||||
const normalized = normalizeSelection(data)
|
||||
setDataSets(normalized)
|
||||
hideSelectDataSet()
|
||||
}
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern 7: Reduce Boolean Logic Complexity
|
||||
|
||||
**Before** (complexity: ~8):
|
||||
|
||||
```typescript
|
||||
const toggleDisabled = hasInsufficientPermissions
|
||||
|| appUnpublished
|
||||
|| missingStartNode
|
||||
|| triggerModeDisabled
|
||||
|| (isAdvancedApp && !currentWorkflow?.graph)
|
||||
|| (isBasicApp && !basicAppConfig.updated_at)
|
||||
```
|
||||
|
||||
**After** (complexity: ~3):
|
||||
|
||||
```typescript
|
||||
// Extract meaningful boolean functions
|
||||
const isAppReady = () => {
|
||||
if (isAdvancedApp) return !!currentWorkflow?.graph
|
||||
return !!basicAppConfig.updated_at
|
||||
}
|
||||
|
||||
const hasRequiredPermissions = () => {
|
||||
return isCurrentWorkspaceEditor && !hasInsufficientPermissions
|
||||
}
|
||||
|
||||
const canToggle = () => {
|
||||
if (!hasRequiredPermissions()) return false
|
||||
if (!isAppReady()) return false
|
||||
if (missingStartNode) return false
|
||||
if (triggerModeDisabled) return false
|
||||
return true
|
||||
}
|
||||
|
||||
const toggleDisabled = !canToggle()
|
||||
```
|
||||
|
||||
## Pattern 8: Simplify useMemo/useCallback Dependencies
|
||||
|
||||
**Before** (complexity: multiple recalculations):
|
||||
|
||||
```typescript
|
||||
const payload = useMemo(() => {
|
||||
let parameters: Parameter[] = []
|
||||
let outputParameters: OutputParameter[] = []
|
||||
|
||||
if (!published) {
|
||||
parameters = (inputs || []).map((item) => ({
|
||||
name: item.variable,
|
||||
description: '',
|
||||
form: 'llm',
|
||||
required: item.required,
|
||||
type: item.type,
|
||||
}))
|
||||
outputParameters = (outputs || []).map((item) => ({
|
||||
name: item.variable,
|
||||
description: '',
|
||||
type: item.value_type,
|
||||
}))
|
||||
}
|
||||
else if (detail && detail.tool) {
|
||||
parameters = (inputs || []).map((item) => ({
|
||||
// Complex transformation...
|
||||
}))
|
||||
outputParameters = (outputs || []).map((item) => ({
|
||||
// Complex transformation...
|
||||
}))
|
||||
}
|
||||
|
||||
return {
|
||||
icon: detail?.icon || icon,
|
||||
label: detail?.label || name,
|
||||
// ...more fields
|
||||
}
|
||||
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
|
||||
```
|
||||
|
||||
**After** (complexity: separated concerns):
|
||||
|
||||
```typescript
|
||||
// Separate transformations
|
||||
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
|
||||
return useMemo(() => {
|
||||
if (!published) {
|
||||
return inputs.map(item => ({
|
||||
name: item.variable,
|
||||
description: '',
|
||||
form: 'llm',
|
||||
required: item.required,
|
||||
type: item.type,
|
||||
}))
|
||||
}
|
||||
|
||||
if (!detail?.tool) return []
|
||||
|
||||
return inputs.map(item => ({
|
||||
name: item.variable,
|
||||
required: item.required,
|
||||
type: item.type === 'paragraph' ? 'string' : item.type,
|
||||
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
|
||||
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
|
||||
}))
|
||||
}, [inputs, detail, published])
|
||||
}
|
||||
|
||||
// Component uses hook
|
||||
const parameters = useParameterTransform(inputs, detail, published)
|
||||
const outputParameters = useOutputTransform(outputs, detail, published)
|
||||
|
||||
const payload = useMemo(() => ({
|
||||
icon: detail?.icon || icon,
|
||||
label: detail?.label || name,
|
||||
parameters,
|
||||
outputParameters,
|
||||
// ...
|
||||
}), [detail, icon, name, parameters, outputParameters])
|
||||
```
|
||||
|
||||
## Target Metrics After Refactoring

| Metric | Target |
|--------|--------|
| Total Complexity | < 50 |
| Max Function Complexity | < 30 |
| Function Length | < 30 lines |
| Nesting Depth | ≤ 3 levels |
| Conditional Chains | ≤ 3 conditions |

@@ -1,477 +0,0 @@
# Component Splitting Patterns

This document provides detailed guidance on splitting large components into smaller, focused components in Dify.

## When to Split Components

Split a component when you identify:

1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
2. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
3. **Repeated patterns** - Similar UI structures used multiple times
4. **300+ lines** - Component exceeds manageable size
5. **Modal clusters** - Multiple modals rendered in one component
## Splitting Strategies
|
||||
|
||||
### Strategy 1: Section-Based Splitting
|
||||
|
||||
Identify visual sections and extract each as a component.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Monolithic component (500+ lines)
|
||||
const ConfigurationPage = () => {
|
||||
return (
|
||||
<div>
|
||||
{/* Header Section - 50 lines */}
|
||||
<div className="header">
|
||||
<h1>{t('configuration.title')}</h1>
|
||||
<div className="actions">
|
||||
{isAdvancedMode && <Badge>Advanced</Badge>}
|
||||
<ModelParameterModal ... />
|
||||
<AppPublisher ... />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Config Section - 200 lines */}
|
||||
<div className="config">
|
||||
<Config />
|
||||
</div>
|
||||
|
||||
{/* Debug Section - 150 lines */}
|
||||
<div className="debug">
|
||||
<Debug ... />
|
||||
</div>
|
||||
|
||||
{/* Modals Section - 100 lines */}
|
||||
{showSelectDataSet && <SelectDataSet ... />}
|
||||
{showHistoryModal && <EditHistoryModal ... />}
|
||||
{showUseGPT4Confirm && <Confirm ... />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Split into focused components
|
||||
// configuration/
|
||||
// ├── index.tsx (orchestration)
|
||||
// ├── configuration-header.tsx
|
||||
// ├── configuration-content.tsx
|
||||
// ├── configuration-debug.tsx
|
||||
// └── configuration-modals.tsx
|
||||
|
||||
// configuration-header.tsx
|
||||
interface ConfigurationHeaderProps {
|
||||
isAdvancedMode: boolean
|
||||
onPublish: () => void
|
||||
}
|
||||
|
||||
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
|
||||
isAdvancedMode,
|
||||
onPublish,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="header">
|
||||
<h1>{t('configuration.title')}</h1>
|
||||
<div className="actions">
|
||||
{isAdvancedMode && <Badge>Advanced</Badge>}
|
||||
<ModelParameterModal ... />
|
||||
<AppPublisher onPublish={onPublish} />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// index.tsx (orchestration only)
|
||||
const ConfigurationPage = () => {
|
||||
const { modelConfig, setModelConfig } = useModelConfig()
|
||||
const { activeModal, openModal, closeModal } = useModalState()
|
||||
|
||||
return (
|
||||
<div>
|
||||
<ConfigurationHeader
|
||||
isAdvancedMode={isAdvancedMode}
|
||||
onPublish={handlePublish}
|
||||
/>
|
||||
<ConfigurationContent
|
||||
modelConfig={modelConfig}
|
||||
onConfigChange={setModelConfig}
|
||||
/>
|
||||
{!isMobile && (
|
||||
<ConfigurationDebug
|
||||
inputs={inputs}
|
||||
onSetting={handleSetting}
|
||||
/>
|
||||
)}
|
||||
<ConfigurationModals
|
||||
activeModal={activeModal}
|
||||
onClose={closeModal}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Strategy 2: Conditional Block Extraction
|
||||
|
||||
Extract large conditional rendering blocks.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Large conditional blocks
|
||||
const AppInfo = () => {
|
||||
return (
|
||||
<div>
|
||||
{expand ? (
|
||||
<div className="expanded">
|
||||
{/* 100 lines of expanded view */}
|
||||
</div>
|
||||
) : (
|
||||
<div className="collapsed">
|
||||
{/* 50 lines of collapsed view */}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Separate view components
|
||||
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||
return (
|
||||
<div className="expanded">
|
||||
{/* Clean, focused expanded view */}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
|
||||
return (
|
||||
<div className="collapsed">
|
||||
{/* Clean, focused collapsed view */}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const AppInfo = () => {
|
||||
return (
|
||||
<div>
|
||||
{expand
|
||||
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
|
||||
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Strategy 3: Modal Extraction
|
||||
|
||||
Extract modals with their trigger logic.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Multiple modals in one component
|
||||
const AppInfo = () => {
|
||||
const [showEdit, setShowEdit] = useState(false)
|
||||
const [showDuplicate, setShowDuplicate] = useState(false)
|
||||
const [showDelete, setShowDelete] = useState(false)
|
||||
const [showSwitch, setShowSwitch] = useState(false)
|
||||
|
||||
const onEdit = async (data) => { /* 20 lines */ }
|
||||
const onDuplicate = async (data) => { /* 20 lines */ }
|
||||
const onDelete = async () => { /* 15 lines */ }
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* Main content */}
|
||||
|
||||
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
|
||||
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
|
||||
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
|
||||
{showSwitch && <SwitchModal ... />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Modal manager component
|
||||
// app-info-modals.tsx
|
||||
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
|
||||
|
||||
interface AppInfoModalsProps {
|
||||
appDetail: AppDetail
|
||||
activeModal: ModalType
|
||||
onClose: () => void
|
||||
onSuccess: () => void
|
||||
}
|
||||
|
||||
const AppInfoModals: FC<AppInfoModalsProps> = ({
|
||||
appDetail,
|
||||
activeModal,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}) => {
|
||||
const handleEdit = async (data) => { /* logic */ }
|
||||
const handleDuplicate = async (data) => { /* logic */ }
|
||||
const handleDelete = async () => { /* logic */ }
|
||||
|
||||
return (
|
||||
<>
|
||||
{activeModal === 'edit' && (
|
||||
<EditModal
|
||||
appDetail={appDetail}
|
||||
onConfirm={handleEdit}
|
||||
onClose={onClose}
|
||||
/>
|
||||
)}
|
||||
{activeModal === 'duplicate' && (
|
||||
<DuplicateModal
|
||||
appDetail={appDetail}
|
||||
onConfirm={handleDuplicate}
|
||||
onClose={onClose}
|
||||
/>
|
||||
)}
|
||||
{activeModal === 'delete' && (
|
||||
<DeleteConfirm
|
||||
onConfirm={handleDelete}
|
||||
onClose={onClose}
|
||||
/>
|
||||
)}
|
||||
{activeModal === 'switch' && (
|
||||
<SwitchModal
|
||||
appDetail={appDetail}
|
||||
onClose={onClose}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
// Parent component
|
||||
const AppInfo = () => {
|
||||
const { activeModal, openModal, closeModal } = useModalState()
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* Main content with openModal triggers */}
|
||||
<Button onClick={() => openModal('edit')}>Edit</Button>
|
||||
|
||||
<AppInfoModals
|
||||
appDetail={appDetail}
|
||||
activeModal={activeModal}
|
||||
onClose={closeModal}
|
||||
onSuccess={handleSuccess}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Strategy 4: List Item Extraction
|
||||
|
||||
Extract repeated item rendering.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: Inline item rendering
|
||||
const OperationsList = () => {
|
||||
return (
|
||||
<div>
|
||||
{operations.map(op => (
|
||||
<div key={op.id} className="operation-item">
|
||||
<span className="icon">{op.icon}</span>
|
||||
<span className="title">{op.title}</span>
|
||||
<span className="description">{op.description}</span>
|
||||
<button onClick={() => op.onClick()}>
|
||||
{op.actionLabel}
|
||||
</button>
|
||||
{op.badge && <Badge>{op.badge}</Badge>}
|
||||
{/* More complex rendering... */}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ After: Extracted item component
|
||||
interface OperationItemProps {
|
||||
operation: Operation
|
||||
onAction: (id: string) => void
|
||||
}
|
||||
|
||||
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
|
||||
return (
|
||||
<div className="operation-item">
|
||||
<span className="icon">{operation.icon}</span>
|
||||
<span className="title">{operation.title}</span>
|
||||
<span className="description">{operation.description}</span>
|
||||
<button onClick={() => onAction(operation.id)}>
|
||||
{operation.actionLabel}
|
||||
</button>
|
||||
{operation.badge && <Badge>{operation.badge}</Badge>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const OperationsList = () => {
|
||||
const handleAction = useCallback((id: string) => {
|
||||
const op = operations.find(o => o.id === id)
|
||||
op?.onClick()
|
||||
}, [operations])
|
||||
|
||||
return (
|
||||
<div>
|
||||
{operations.map(op => (
|
||||
<OperationItem
|
||||
key={op.id}
|
||||
operation={op}
|
||||
onAction={handleAction}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Directory Structure Patterns
|
||||
|
||||
### Pattern A: Flat Structure (Simple Components)
|
||||
|
||||
For components with 2-3 sub-components:
|
||||
|
||||
```
|
||||
component-name/
|
||||
├── index.tsx # Main component
|
||||
├── sub-component-a.tsx
|
||||
├── sub-component-b.tsx
|
||||
└── types.ts # Shared types
|
||||
```
|
||||
|
||||
### Pattern B: Nested Structure (Complex Components)
|
||||
|
||||
For components with many sub-components:
|
||||
|
||||
```
|
||||
component-name/
|
||||
├── index.tsx # Main orchestration
|
||||
├── types.ts # Shared types
|
||||
├── hooks/
|
||||
│ ├── use-feature-a.ts
|
||||
│ └── use-feature-b.ts
|
||||
├── components/
|
||||
│ ├── header/
|
||||
│ │ └── index.tsx
|
||||
│ ├── content/
|
||||
│ │ └── index.tsx
|
||||
│ └── modals/
|
||||
│ └── index.tsx
|
||||
└── utils/
|
||||
└── helpers.ts
|
||||
```
|
||||
|
||||
### Pattern C: Feature-Based Structure (Dify Standard)
|
||||
|
||||
Following Dify's existing patterns:
|
||||
|
||||
```
|
||||
configuration/
|
||||
├── index.tsx # Main page component
|
||||
├── base/ # Base/shared components
|
||||
│ ├── feature-panel/
|
||||
│ ├── group-name/
|
||||
│ └── operation-btn/
|
||||
├── config/ # Config section
|
||||
│ ├── index.tsx
|
||||
│ ├── agent/
|
||||
│ └── automatic/
|
||||
├── dataset-config/ # Dataset section
|
||||
│ ├── index.tsx
|
||||
│ ├── card-item/
|
||||
│ └── params-config/
|
||||
├── debug/ # Debug section
|
||||
│ ├── index.tsx
|
||||
│ └── hooks.tsx
|
||||
└── hooks/ # Shared hooks
|
||||
└── use-advanced-prompt-config.ts
|
||||
```
|
||||
|
||||
## Props Design
|
||||
|
||||
### Minimal Props Principle
|
||||
|
||||
Pass only what's needed:
|
||||
|
||||
```typescript
|
||||
// ❌ Bad: Passing entire objects when only some fields needed
|
||||
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
|
||||
|
||||
// ✅ Good: Destructure to minimum required
|
||||
<ConfigHeader
|
||||
appName={appDetail.name}
|
||||
isAdvancedMode={modelConfig.isAdvanced}
|
||||
onPublish={handlePublish}
|
||||
/>
|
||||
```
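The receiving component can then declare exactly that contract. A sketch using the prop names from the good example above; `ConfigHeader` itself is illustrative, and `Badge`/`AppPublisher` are assumed from the earlier snippets:

```typescript
interface ConfigHeaderProps {
  appName: string
  isAdvancedMode: boolean
  onPublish: () => void
}

// Depending on three primitives keeps re-renders narrow and makes the
// component trivial to test in isolation.
const ConfigHeader: FC<ConfigHeaderProps> = ({ appName, isAdvancedMode, onPublish }) => (
  <div className="header">
    <h1>{appName}</h1>
    {isAdvancedMode && <Badge>Advanced</Badge>}
    <AppPublisher onPublish={onPublish} />
  </div>
)
```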
|
||||
|
||||
### Callback Props Pattern
|
||||
|
||||
Use callbacks for child-to-parent communication:
|
||||
|
||||
```typescript
|
||||
// Parent
|
||||
const Parent = () => {
|
||||
const [value, setValue] = useState('')
|
||||
|
||||
return (
|
||||
<Child
|
||||
value={value}
|
||||
onChange={setValue}
|
||||
onSubmit={handleSubmit}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
// Child
|
||||
interface ChildProps {
|
||||
value: string
|
||||
onChange: (value: string) => void
|
||||
onSubmit: () => void
|
||||
}
|
||||
|
||||
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
|
||||
return (
|
||||
<div>
|
||||
<input value={value} onChange={e => onChange(e.target.value)} />
|
||||
<button onClick={onSubmit}>Submit</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Render Props for Flexibility
|
||||
|
||||
When sub-components need parent context:
|
||||
|
||||
```typescript
|
||||
interface ListProps<T> {
|
||||
items: T[]
|
||||
renderItem: (item: T, index: number) => React.ReactNode
|
||||
renderEmpty?: () => React.ReactNode
|
||||
}
|
||||
|
||||
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
|
||||
if (items.length === 0 && renderEmpty) {
|
||||
return <>{renderEmpty()}</>
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
{items.map((item, index) => renderItem(item, index))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Usage
|
||||
<List
|
||||
items={operations}
|
||||
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
|
||||
renderEmpty={() => <EmptyState message="No operations" />}
|
||||
/>
|
||||
```
|
||||
@@ -1,317 +0,0 @@
# Hook Extraction Patterns

This document provides detailed guidance on extracting custom hooks from complex components in Dify.

## When to Extract Hooks

Extract a custom hook when you identify:

1. **Coupled state groups** - Multiple `useState` hooks that are always used together
2. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
3. **Business logic** - Data transformations, validations, or calculations
4. **Reusable patterns** - Logic that appears in multiple components
## Extraction Process
|
||||
|
||||
### Step 1: Identify State Groups
|
||||
|
||||
Look for state variables that are logically related:
|
||||
|
||||
```typescript
|
||||
// ❌ These belong together - extract to hook
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)
|
||||
|
||||
// These are model-related state that should be in useModelConfig()
|
||||
```
|
||||
|
||||
### Step 2: Identify Related Effects
|
||||
|
||||
Find effects that modify the grouped state:
|
||||
|
||||
```typescript
|
||||
// ❌ These effects belong with the state above
|
||||
useEffect(() => {
|
||||
if (hasFetchedDetail && !modelModeType) {
|
||||
const mode = currModel?.model_properties.mode
|
||||
if (mode) {
|
||||
const newModelConfig = produce(modelConfig, (draft) => {
|
||||
draft.mode = mode
|
||||
})
|
||||
setModelConfig(newModelConfig)
|
||||
}
|
||||
}
|
||||
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
|
||||
```
|
||||
|
||||
### Step 3: Create the Hook
|
||||
|
||||
```typescript
|
||||
// hooks/use-model-config.ts
|
||||
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { ModelConfig } from '@/models/debug'
|
||||
import { produce } from 'immer'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { ModelModeType } from '@/types/app'
|
||||
|
||||
interface UseModelConfigParams {
|
||||
initialConfig?: Partial<ModelConfig>
|
||||
currModel?: { model_properties?: { mode?: ModelModeType } }
|
||||
hasFetchedDetail: boolean
|
||||
}
|
||||
|
||||
interface UseModelConfigReturn {
|
||||
modelConfig: ModelConfig
|
||||
setModelConfig: (config: ModelConfig) => void
|
||||
completionParams: FormValue
|
||||
setCompletionParams: (params: FormValue) => void
|
||||
modelModeType: ModelModeType
|
||||
}
|
||||
|
||||
export const useModelConfig = ({
|
||||
initialConfig,
|
||||
currModel,
|
||||
hasFetchedDetail,
|
||||
}: UseModelConfigParams): UseModelConfigReturn => {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>({
|
||||
provider: 'langgenius/openai/openai',
|
||||
model_id: 'gpt-3.5-turbo',
|
||||
mode: ModelModeType.unset,
|
||||
// ... default values
|
||||
...initialConfig,
|
||||
})
|
||||
|
||||
const [completionParams, setCompletionParams] = useState<FormValue>({})
|
||||
|
||||
const modelModeType = modelConfig.mode
|
||||
|
||||
// Fill old app data missing model mode
|
||||
useEffect(() => {
|
||||
if (hasFetchedDetail && !modelModeType) {
|
||||
const mode = currModel?.model_properties?.mode
|
||||
if (mode) {
|
||||
setModelConfig(produce(modelConfig, (draft) => {
|
||||
draft.mode = mode
|
||||
}))
|
||||
}
|
||||
}
|
||||
}, [hasFetchedDetail, modelModeType, currModel])
|
||||
|
||||
return {
|
||||
modelConfig,
|
||||
setModelConfig,
|
||||
completionParams,
|
||||
setCompletionParams,
|
||||
modelModeType,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 4: Update Component
|
||||
|
||||
```typescript
|
||||
// Before: 50+ lines of state management
|
||||
const Configuration: FC = () => {
|
||||
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
|
||||
// ... lots of related state and effects
|
||||
}
|
||||
|
||||
// After: Clean component
|
||||
const Configuration: FC = () => {
|
||||
const {
|
||||
modelConfig,
|
||||
setModelConfig,
|
||||
completionParams,
|
||||
setCompletionParams,
|
||||
modelModeType,
|
||||
} = useModelConfig({
|
||||
currModel,
|
||||
hasFetchedDetail,
|
||||
})
|
||||
|
||||
// Component now focuses on UI
|
||||
}
|
||||
```
|
||||
|
||||
## Naming Conventions
|
||||
|
||||
### Hook Names
|
||||
|
||||
- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
|
||||
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
|
||||
- Include domain: `useWorkflowVariables`, `useMCPServer`
|
||||
|
||||
### File Names
|
||||
|
||||
- Kebab-case: `use-model-config.ts`
|
||||
- Place in `hooks/` subdirectory when multiple hooks exist
|
||||
- Place alongside component for single-use hooks
|
||||
|
||||
### Return Type Names
|
||||
|
||||
- Suffix with `Return`: `UseModelConfigReturn`
|
||||
- Suffix params with `Params`: `UseModelConfigParams`
|
||||
|
||||
## Common Hook Patterns in Dify
|
||||
|
||||
### 1. Data Fetching Hook (React Query)
|
||||
|
||||
```typescript
|
||||
// Pattern: Use @tanstack/react-query for data fetching
|
||||
import { useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
import { useInvalid } from '@/service/use-base'
|
||||
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
// Query keys for cache management
|
||||
export const appConfigQueryKeys = {
|
||||
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
|
||||
}
|
||||
|
||||
// Main data hook
|
||||
export const useAppConfig = (appId: string) => {
|
||||
return useQuery({
|
||||
enabled: !!appId,
|
||||
queryKey: appConfigQueryKeys.detail(appId),
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || null,
|
||||
})
|
||||
}
|
||||
|
||||
// Invalidation hook for refreshing data
|
||||
export const useInvalidAppConfig = () => {
|
||||
return useInvalid([NAME_SPACE])
|
||||
}
|
||||
|
||||
// Usage in component
|
||||
const Component = () => {
|
||||
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
|
||||
const invalidAppConfig = useInvalidAppConfig()
|
||||
|
||||
const handleRefresh = () => {
|
||||
invalidAppConfig() // Invalidates cache and triggers refetch
|
||||
}
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Form State Hook
|
||||
|
||||
```typescript
|
||||
// Pattern: Form state + validation + submission
|
||||
export const useConfigForm = (initialValues: ConfigFormValues) => {
|
||||
const [values, setValues] = useState(initialValues)
|
||||
const [errors, setErrors] = useState<Record<string, string>>({})
|
||||
const [isSubmitting, setIsSubmitting] = useState(false)
|
||||
|
||||
const validate = useCallback(() => {
|
||||
const newErrors: Record<string, string> = {}
|
||||
if (!values.name) newErrors.name = 'Name is required'
|
||||
setErrors(newErrors)
|
||||
return Object.keys(newErrors).length === 0
|
||||
}, [values])
|
||||
|
||||
const handleChange = useCallback((field: string, value: any) => {
|
||||
setValues(prev => ({ ...prev, [field]: value }))
|
||||
}, [])
|
||||
|
||||
const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
|
||||
if (!validate()) return
|
||||
setIsSubmitting(true)
|
||||
try {
|
||||
await onSubmit(values)
|
||||
} finally {
|
||||
setIsSubmitting(false)
|
||||
}
|
||||
}, [values, validate])
|
||||
|
||||
return { values, errors, isSubmitting, handleChange, handleSubmit }
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Modal State Hook
|
||||
|
||||
```typescript
|
||||
// Pattern: Multiple modal management
|
||||
type ModalType = 'edit' | 'delete' | 'duplicate' | null
|
||||
|
||||
export const useModalState = () => {
|
||||
const [activeModal, setActiveModal] = useState<ModalType>(null)
|
||||
const [modalData, setModalData] = useState<any>(null)
|
||||
|
||||
const openModal = useCallback((type: ModalType, data?: any) => {
|
||||
setActiveModal(type)
|
||||
setModalData(data)
|
||||
}, [])
|
||||
|
||||
const closeModal = useCallback(() => {
|
||||
setActiveModal(null)
|
||||
setModalData(null)
|
||||
}, [])
|
||||
|
||||
return {
|
||||
activeModal,
|
||||
modalData,
|
||||
openModal,
|
||||
closeModal,
|
||||
isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Toggle/Boolean Hook
|
||||
|
||||
```typescript
|
||||
// Pattern: Boolean state with convenience methods
|
||||
export const useToggle = (initialValue = false) => {
|
||||
const [value, setValue] = useState(initialValue)
|
||||
|
||||
const toggle = useCallback(() => setValue(v => !v), [])
|
||||
const setTrue = useCallback(() => setValue(true), [])
|
||||
const setFalse = useCallback(() => setValue(false), [])
|
||||
|
||||
return [value, { toggle, setTrue, setFalse, set: setValue }] as const
|
||||
}
|
||||
|
||||
// Usage
|
||||
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
|
||||
```
|
||||
|
||||
## Testing Extracted Hooks
|
||||
|
||||
After extraction, test hooks in isolation:
|
||||
|
||||
```typescript
|
||||
// use-model-config.spec.ts
|
||||
import { renderHook, act } from '@testing-library/react'
|
||||
import { useModelConfig } from './use-model-config'
|
||||
|
||||
describe('useModelConfig', () => {
|
||||
it('should initialize with default values', () => {
|
||||
const { result } = renderHook(() => useModelConfig({
|
||||
hasFetchedDetail: false,
|
||||
}))
|
||||
|
||||
expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
|
||||
expect(result.current.modelModeType).toBe(ModelModeType.unset)
|
||||
})
|
||||
|
||||
it('should update model config', () => {
|
||||
const { result } = renderHook(() => useModelConfig({
|
||||
hasFetchedDetail: true,
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
result.current.setModelConfig({
|
||||
...result.current.modelConfig,
|
||||
model_id: 'gpt-4',
|
||||
})
|
||||
})
|
||||
|
||||
expect(result.current.modelConfig.model_id).toBe('gpt-4')
|
||||
})
|
||||
})
|
||||
```
|
||||
@@ -318,5 +318,5 @@ For more detailed information, refer to:

 - `web/vitest.config.ts` - Vitest configuration
 - `web/vitest.setup.ts` - Test environment setup
-- `web/scripts/analyze-component.js` - Component analysis tool
+- `web/testing/analyze-component.js` - Component analysis tool
 - Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
.github/workflows/api-tests.yml (vendored, 6 changes)

@@ -22,12 +22,12 @@ jobs:

    steps:
      - name: Checkout code
-       uses: actions/checkout@v6
+       uses: actions/checkout@v4
        with:
          persist-credentials: false

      - name: Setup UV and Python
-       uses: astral-sh/setup-uv@v7
+       uses: astral-sh/setup-uv@v6
        with:
          enable-cache: true
          python-version: ${{ matrix.python-version }}

@@ -57,7 +57,7 @@ jobs:
        run: sh .github/workflows/expose_service_ports.sh

      - name: Set up Sandbox
-       uses: hoverkraft-tech/compose-action@v2
+       uses: hoverkraft-tech/compose-action@v2.0.2
        with:
          compose-file: |
            docker/docker-compose.middleware.yaml
4
.github/workflows/autofix.yml
vendored
@ -12,7 +12,7 @@ jobs:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check Docker Compose inputs
|
||||
id: docker-compose-changes
|
||||
@ -27,7 +27,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
|
||||
2
.github/workflows/build-push.yml
vendored
@ -90,7 +90,7 @@ jobs:
|
||||
touch "/tmp/digests/${sanitized_digest}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
|
||||
8
.github/workflows/db-migration-test.yml
vendored
@ -13,13 +13,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -63,13 +63,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
3
.github/workflows/main-ci.yml
vendored
@ -27,7 +27,7 @@ jobs:
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: changes
|
||||
with:
|
||||
@ -38,7 +38,6 @@ jobs:
|
||||
- '.github/workflows/api-tests.yml'
|
||||
web:
|
||||
- 'web/**'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'docker/**'
|
||||
|
||||
22
.github/workflows/style.yml
vendored
@ -19,13 +19,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@ -68,17 +68,15 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/style.yml
|
||||
files: web/**
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
@ -87,10 +85,10 @@ jobs:
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
@ -116,14 +114,14 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@v46
|
||||
with:
|
||||
files: |
|
||||
**.sh
|
||||
|
||||
13
.github/workflows/tool-test-sdks.yaml
vendored
@ -16,23 +16,24 @@ jobs:
|
||||
name: unit test for Node.js SDK
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [16, 18, 20, 22]
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: sdks/nodejs-client
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
<<<<<<< HEAD
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v4
|
||||
=======
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v6
|
||||
>>>>>>> 328897f81c (build: require node 24.13.0 (#30945))
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
node-version: 24
|
||||
cache: ''
|
||||
cache-dependency-path: 'pnpm-lock.yaml'
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ jobs:
|
||||
run:
|
||||
working-directory: web
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -51,7 +51,7 @@ jobs:
|
||||
|
||||
- name: Set up Node.js
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
cache: pnpm
|
||||
|
||||
421
.github/workflows/translate-i18n-claude.yml
vendored
Normal file
@ -0,0 +1,421 @@
|
||||
name: Translate i18n Files with Claude Code
|
||||
|
||||
# Note: claude-code-action doesn't support push events directly.
|
||||
# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
|
||||
# See: https://github.com/langgenius/dify/issues/30743
|
||||
|
||||
on:
|
||||
repository_dispatch:
|
||||
types: [i18n-sync]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
files:
|
||||
description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.'
|
||||
required: false
|
||||
type: string
|
||||
languages:
|
||||
description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.'
|
||||
required: false
|
||||
type: string
|
||||
mode:
|
||||
description: 'Sync mode: incremental (only changes) or full (re-check all keys)'
|
||||
required: false
|
||||
default: 'incremental'
|
||||
type: choice
|
||||
options:
|
||||
- incremental
|
||||
- full
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
translate:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 60
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Configure Git
|
||||
run: |
|
||||
git config --global user.name "github-actions[bot]"
|
||||
git config --global user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Detect changed files and generate diff
|
||||
id: detect_changes
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
# Manual trigger
|
||||
if [ -n "${{ github.event.inputs.files }}" ]; then
|
||||
echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
# Get all JSON files in en-US directory
|
||||
files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ')
|
||||
echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT
|
||||
echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT
|
||||
|
||||
# For manual trigger with incremental mode, get diff from last commit
|
||||
# For full mode, we'll do a complete check anyway
|
||||
if [ "${{ github.event.inputs.mode }}" == "full" ]; then
|
||||
echo "Full mode: will check all keys" > /tmp/i18n-diff.txt
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
else
|
||||
git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
|
||||
if [ -s /tmp/i18n-diff.txt ]; then
|
||||
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
|
||||
# Triggered by push via trigger-i18n-sync.yml workflow
|
||||
# Validate required payload fields
|
||||
if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
|
||||
echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
|
||||
echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
|
||||
echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
|
||||
|
||||
# Decode the base64-encoded diff from the trigger workflow
|
||||
if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
|
||||
if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
|
||||
echo "Warning: Failed to decode base64 diff payload" >&2
|
||||
echo "" > /tmp/i18n-diff.txt
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
elif [ -s /tmp/i18n-diff.txt ]; then
|
||||
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
else
|
||||
echo "" > /tmp/i18n-diff.txt
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
else
|
||||
echo "Unsupported event type: ${{ github.event_name }}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Truncate diff if too large (keep first 50KB)
|
||||
if [ -f /tmp/i18n-diff.txt ]; then
|
||||
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
|
||||
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
|
||||
fi
|
||||
|
||||
echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')"
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
prompt: |
|
||||
You are a professional i18n synchronization engineer for the Dify project.
|
||||
Your task is to keep all language translations in sync with the English source (en-US).
|
||||
|
||||
## CRITICAL TOOL RESTRICTIONS
|
||||
- Use **Read** tool to read files (NOT cat or bash)
|
||||
- Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts)
|
||||
- Use **Bash** ONLY for: git commands, gh commands, pnpm commands
|
||||
- Run bash commands ONE BY ONE, never combine with && or ||
|
||||
- NEVER use `$()` command substitution - it's not supported. Split into separate commands instead.
|
||||
|
||||
## WORKING DIRECTORY & ABSOLUTE PATHS
|
||||
Claude Code sandbox working directory may vary. Always use absolute paths:
|
||||
- For pnpm: `pnpm --dir ${{ github.workspace }}/web <command>`
|
||||
- For git: `git -C ${{ github.workspace }} <command>`
|
||||
- For gh: `gh --repo ${{ github.repository }} <command>`
|
||||
- For file paths: `${{ github.workspace }}/web/i18n/`
|
||||
|
||||
## EFFICIENCY RULES
|
||||
- **ONE Edit per language file** - batch all key additions into a single Edit
|
||||
- Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them
|
||||
- Translate ALL keys for a language mentally first, then do ONE Edit
|
||||
|
||||
## Context
|
||||
- Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
|
||||
- Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }}
|
||||
- Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
|
||||
- Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json
|
||||
- Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts
|
||||
- Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }}
|
||||
|
||||
## CRITICAL DESIGN: Verify First, Then Sync
|
||||
|
||||
You MUST follow this three-phase approach:
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
║ PHASE 1: VERIFY - Analyze and Generate Change Report ║
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
### Step 1.1: Analyze Git Diff (for incremental mode)
|
||||
Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff.
|
||||
|
||||
Parse the diff to categorize changes:
|
||||
- Lines with `+` (not `+++`): Added or modified values
|
||||
- Lines with `-` (not `---`): Removed or old values
|
||||
- Identify specific keys for each category:
|
||||
* ADD: Keys that appear only in `+` lines (new keys)
|
||||
* UPDATE: Keys that appear in both `-` and `+` lines (value changed)
|
||||
* DELETE: Keys that appear only in `-` lines (removed keys)
|
||||
|
||||
### Step 1.2: Read Language Configuration
|
||||
Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`.
|
||||
Extract all languages with `supported: true`.
|
||||
|
||||
### Step 1.3: Run i18n:check for Each Language
|
||||
```bash
|
||||
pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile
|
||||
```
|
||||
```bash
|
||||
pnpm --dir ${{ github.workspace }}/web run i18n:check
|
||||
```
|
||||
|
||||
This will report:
|
||||
- Missing keys (need to ADD)
|
||||
- Extra keys (need to DELETE)
|
||||
|
||||
### Step 1.4: Generate Change Report
|
||||
|
||||
Create a structured report identifying:
|
||||
```
|
||||
╔══════════════════════════════════════════════════════════════╗
|
||||
║ I18N SYNC CHANGE REPORT ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ Files to process: [list] ║
|
||||
║ Languages to sync: [list] ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ ADD (New Keys): ║
|
||||
║ - [filename].[key]: "English value" ║
|
||||
║ ... ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ UPDATE (Modified Keys - MUST re-translate): ║
|
||||
║ - [filename].[key]: "Old value" → "New value" ║
|
||||
║ ... ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ DELETE (Extra Keys): ║
|
||||
║ - [language]/[filename].[key] ║
|
||||
║ ... ║
|
||||
╚══════════════════════════════════════════════════════════════╝
|
||||
```
|
||||
|
||||
**IMPORTANT**: For UPDATE detection, compare git diff to find keys where
|
||||
the English value changed. These MUST be re-translated even if target
|
||||
language already has a translation (it's now stale!).
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
║ PHASE 2: SYNC - Execute Changes Based on Report ║
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
### Step 2.1: Process ADD Operations (BATCH per language file)
|
||||
|
||||
**CRITICAL WORKFLOW for efficiency:**
|
||||
1. First, translate ALL new keys for ALL languages mentally
|
||||
2. Then, for EACH language file, do ONE Edit operation:
|
||||
- Read the file once
|
||||
- Insert ALL new keys at the beginning (right after the opening `{`)
|
||||
- Don't worry about alphabetical order - lint:fix will sort them later
|
||||
|
||||
Example Edit (adding 3 keys to zh-Hans/app.json):
|
||||
```
|
||||
old_string: '{\n "accessControl"'
|
||||
new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"'
|
||||
```
|
||||
|
||||
**IMPORTANT**:
|
||||
- ONE Edit per language file (not one Edit per key!)
|
||||
- Always use the Edit tool. NEVER use bash scripts, node, or jq.
|
||||
|
||||
### Step 2.2: Process UPDATE Operations
|
||||
|
||||
**IMPORTANT: Special handling for zh-Hans and ja-JP**
|
||||
If zh-Hans or ja-JP files were ALSO modified in the same push:
|
||||
- Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files
|
||||
- If found, it means someone manually translated them. Apply these rules:
|
||||
|
||||
1. **Missing keys**: Still ADD them (completeness required)
|
||||
2. **Existing translations**: Compare with the NEW English value:
|
||||
- If translation is **completely wrong** or **unrelated** → Update it
|
||||
- If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work
|
||||
- When in doubt, **keep the manual translation**
|
||||
|
||||
Example:
|
||||
- English changed: "Save" → "Save Changes"
|
||||
- Manual translation: "保存更改" → Keep it (correct meaning)
|
||||
- Manual translation: "删除" → Update it (completely wrong)
|
||||
|
||||
For other languages:
|
||||
Use Edit tool to replace the old value with the new translation.
|
||||
You can batch multiple updates in one Edit if they are adjacent.
|
||||
|
||||
### Step 2.3: Process DELETE Operations
|
||||
For extra keys reported by i18n:check:
|
||||
- Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove`
|
||||
- Or manually remove from target language JSON files
|
||||
|
||||
## Translation Guidelines
|
||||
|
||||
- PRESERVE all placeholders exactly as-is:
|
||||
- `{{variable}}` - Mustache interpolation
|
||||
- `${variable}` - Template literal
|
||||
- `<tag>content</tag>` - HTML tags
|
||||
- `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values)
|
||||
- Use appropriate language register (formal/informal) based on existing translations
|
||||
- Match existing translation style in each language
|
||||
- Technical terms: check existing conventions per language
|
||||
- For CJK languages: no spaces between characters unless necessary
|
||||
- For RTL languages (ar-TN, fa-IR): ensure proper text handling
|
||||
|
||||
## Output Format Requirements
|
||||
- Alphabetical key ordering (if original file uses it)
|
||||
- 2-space indentation
|
||||
- Trailing newline at end of file
|
||||
- Valid JSON (use proper escaping for special characters)
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
### Step 3.1: Run Lint Fix (IMPORTANT!)
|
||||
```bash
|
||||
pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json'
|
||||
```
|
||||
This ensures:
|
||||
- JSON keys are sorted alphabetically (jsonc/sort-keys rule)
|
||||
- Valid i18n keys (dify-i18n/valid-i18n-keys rule)
|
||||
- No extra keys (dify-i18n/no-extra-keys rule)
|
||||
|
||||
### Step 3.2: Run Final i18n Check
|
||||
```bash
|
||||
pnpm --dir ${{ github.workspace }}/web run i18n:check
|
||||
```
|
||||
|
||||
### Step 3.3: Fix Any Remaining Issues
|
||||
If check reports issues:
|
||||
- Go back to PHASE 2 for unresolved items
|
||||
- Repeat until check passes
|
||||
|
||||
### Step 3.4: Generate Final Summary
|
||||
```
|
||||
╔══════════════════════════════════════════════════════════════╗
|
||||
║ SYNC COMPLETED SUMMARY ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ Language │ Added │ Updated │ Deleted │ Status ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║
|
||||
║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║
|
||||
║ ... │ ... │ ... │ ... │ ... ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ i18n:check │ PASSED - All keys in sync ║
|
||||
╚══════════════════════════════════════════════════════════════╝
|
||||
```
|
||||
|
||||
## Mode-Specific Behavior
|
||||
|
||||
**SYNC_MODE = "incremental"** (default):
|
||||
- Focus on keys identified from git diff
|
||||
- Also check i18n:check output for any missing/extra keys
|
||||
- Efficient for small changes
|
||||
|
||||
**SYNC_MODE = "full"**:
|
||||
- Compare ALL keys between en-US and each language
|
||||
- Run i18n:check to identify all discrepancies
|
||||
- Use for first-time sync or fixing historical issues
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. Always run i18n:check BEFORE and AFTER making changes
|
||||
2. The check script is the source of truth for missing/extra keys
|
||||
3. For UPDATE scenario: git diff is the source of truth for changed values
|
||||
4. Create a single commit with all translation changes
|
||||
5. If any translation fails, continue with others and report failures
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
║ PHASE 4: COMMIT AND CREATE PR ║
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
After all translations are complete and verified:
|
||||
|
||||
### Step 4.1: Check for changes
|
||||
```bash
|
||||
git -C ${{ github.workspace }} status --porcelain
|
||||
```
|
||||
|
||||
If there are changes:
|
||||
|
||||
### Step 4.2: Create a new branch and commit
|
||||
Run these git commands ONE BY ONE (not combined with &&).
|
||||
**IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands:
|
||||
|
||||
1. First, get the timestamp:
|
||||
```bash
|
||||
date +%Y%m%d-%H%M%S
|
||||
```
|
||||
(Note the output, e.g., "20260115-143052")
|
||||
|
||||
2. Then create branch using the timestamp value:
|
||||
```bash
|
||||
git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052
|
||||
```
|
||||
(Replace "20260115-143052" with the actual timestamp from step 1)
|
||||
|
||||
3. Stage changes:
|
||||
```bash
|
||||
git -C ${{ github.workspace }} add web/i18n/
|
||||
```
|
||||
|
||||
4. Commit:
|
||||
```bash
|
||||
git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}"
|
||||
```
|
||||
|
||||
5. Push:
|
||||
```bash
|
||||
git -C ${{ github.workspace }} push origin HEAD
|
||||
```
|
||||
|
||||
### Step 4.3: Create Pull Request
|
||||
```bash
|
||||
gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary
|
||||
|
||||
This PR was automatically generated to sync i18n translation files.
|
||||
|
||||
### Changes
|
||||
- Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
|
||||
- Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
|
||||
|
||||
### Verification
|
||||
- [x] \`i18n:check\` passed
|
||||
- [x] \`lint:fix\` applied
|
||||
|
||||
🤖 Generated with Claude Code GitHub Action" --base main
|
||||
```
|
||||
|
||||
claude_args: |
|
||||
--max-turns 150
|
||||
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"
|
||||
6
.github/workflows/vdb-tests.yml
vendored
@ -19,19 +19,19 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Free Disk Space
|
||||
uses: endersonmenezes/free-disk-space@v3
|
||||
uses: endersonmenezes/free-disk-space@v2
|
||||
with:
|
||||
remove_dotnet: true
|
||||
remove_haskell: true
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
8
.github/workflows/web-tests.yml
vendored
@ -18,7 +18,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -29,9 +29,9 @@ jobs:
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
@ -360,7 +360,7 @@ jobs:
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: web-coverage-report
|
||||
path: web/coverage
|
||||
|
||||
34
.mcp.json
Normal file
@ -0,0 +1,34 @@
{
  "mcpServers": {
    "context7": {
      "type": "http",
      "url": "https://mcp.context7.com/mcp"
    },
    "sequential-thinking": {
      "type": "stdio",
      "command": "npx",
      "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
      "env": {}
    },
    "github": {
      "type": "stdio",
      "command": "npx",
      "args": ["-y", "@modelcontextprotocol/server-github"],
      "env": {
        "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
      }
    },
    "fetch": {
      "type": "stdio",
      "command": "uvx",
      "args": ["mcp-server-fetch"],
      "env": {}
    },
    "playwright": {
      "type": "stdio",
      "command": "npx",
      "args": ["-y", "@playwright/mcp@latest"],
      "env": {}
    }
  }
}
@ -50,16 +50,33 @@ WORKDIR /app/api
|
||||
|
||||
# Create non-root user
|
||||
ARG dify_uid=1001
|
||||
ARG NODE_MAJOR=22
|
||||
ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1
|
||||
ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4
|
||||
RUN groupadd -r -g ${dify_uid} dify && \
|
||||
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
|
||||
chown -R dify:dify /app
|
||||
|
||||
RUN \
|
||||
apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
gnupg \
|
||||
&& mkdir -p /etc/apt/keyrings \
|
||||
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \
|
||||
&& gpg --show-keys --with-colons /tmp/nodesource.gpg \
|
||||
| awk -F: '/^fpr:/ {print $10}' \
|
||||
| grep -Fx "${NODESOURCE_KEY_FPR}" \
|
||||
&& gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \
|
||||
&& rm -f /tmp/nodesource.gpg \
|
||||
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \
|
||||
> /etc/apt/sources.list.d/nodesource.list \
|
||||
&& apt-get update \
|
||||
# Install dependencies
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
curl nodejs \
|
||||
nodejs=${NODE_PACKAGE_VERSION} \
|
||||
# for gmpy2 \
|
||||
libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import base64
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -16,8 +15,22 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SubscriptionQuery(BaseModel):
|
||||
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
|
||||
interval: Literal["month", "year"] = Field(..., description="Billing interval")
|
||||
plan: str = Field(..., description="Subscription plan")
|
||||
interval: str = Field(..., description="Billing interval")
|
||||
|
||||
@field_validator("plan")
|
||||
@classmethod
|
||||
def validate_plan(cls, value: str) -> str:
|
||||
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
|
||||
raise ValueError("Invalid plan")
|
||||
return value
|
||||
|
||||
@field_validator("interval")
|
||||
@classmethod
|
||||
def validate_interval(cls, value: str) -> str:
|
||||
if value not in {"month", "year"}:
|
||||
raise ValueError("Invalid interval")
|
||||
return value
|
||||
|
||||
|
||||
class PartnerTenantsPayload(BaseModel):
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
@ -8,19 +10,19 @@ from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
last_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
message_id: UUID
|
||||
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@ -12,20 +10,10 @@ from models import TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
class LoadBalancingCredentialPayload(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict[str, object]
|
||||
|
||||
|
||||
register_schema_models(console_ns, LoadBalancingCredentialPayload)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
|
||||
)
|
||||
class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -36,7 +24,20 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# validate model load balancing credentials
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
@ -48,9 +49,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
model_load_balancing_service.validate_load_balancing_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=payload.model,
|
||||
model_type=payload.model_type,
|
||||
credentials=payload.credentials,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
@ -68,7 +69,6 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
|
||||
)
|
||||
class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -79,7 +79,20 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# validate model load balancing config credentials
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
@ -91,9 +104,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
model_load_balancing_service.validate_load_balancing_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=payload.model,
|
||||
model_type=payload.model_type,
|
||||
credentials=payload.credentials,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
config_id=config_id,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import io
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from flask import make_response, redirect, request, send_file
|
||||
@ -18,7 +17,6 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
@ -42,8 +40,6 @@ from services.tools.tools_manage_service import ToolCommonService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
if not url:
|
||||
@ -949,8 +945,8 @@ class ToolProviderMCPApi(Resource):
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# 1) Create provider in a short transaction (no network I/O inside)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
# Create provider in transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=tenant_id,
|
||||
@ -966,28 +962,7 @@ class ToolProviderMCPApi(Resource):
|
||||
authentication=authentication,
|
||||
)
|
||||
|
||||
# 2) Try to fetch tools immediately after creation so they appear without a second save.
|
||||
# Perform network I/O outside any DB session to avoid holding locks.
|
||||
try:
|
||||
reconnect = MCPToolManageService.reconnect_with_url(
|
||||
server_url=args["server_url"],
|
||||
headers=args.get("headers") or {},
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
)
|
||||
# Update just-created provider with authed/tools in a new short transaction
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
|
||||
db_provider.authed = reconnect.authed
|
||||
db_provider.tools = reconnect.tools
|
||||
|
||||
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
|
||||
except Exception:
|
||||
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
|
||||
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
|
||||
|
||||
# Final cache invalidation to ensure list views are up to date
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@ -13,6 +13,7 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
validate_dataset_token,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
@ -459,8 +460,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def get(self, _):
|
||||
def get(self, _, dataset_id):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
@ -480,7 +482,8 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def post(self, _):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -503,7 +506,8 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def patch(self, _):
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@ -529,8 +533,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
@edit_permission_required
|
||||
def delete(self, _):
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
@ -550,7 +555,8 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
def post(self, _):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -574,7 +580,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
def post(self, _):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -597,6 +604,7 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
||||
@ -10,7 +10,12 @@ from controllers.console.auth.error import (
|
||||
InvalidEmailError,
|
||||
)
|
||||
from controllers.console.error import AccountBannedError
|
||||
from controllers.console.wraps import only_edition_enterprise, setup_required
|
||||
from controllers.console.wraps import (
|
||||
decrypt_code_field,
|
||||
decrypt_password_field,
|
||||
only_edition_enterprise,
|
||||
setup_required,
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import decode_jwt_token
|
||||
from libs.helper import email
|
||||
@ -42,6 +47,7 @@ class LoginApi(Resource):
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
@decrypt_password_field
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = (
|
||||
@ -181,6 +187,7 @@ class EmailCodeLoginApi(Resource):
|
||||
404: "Account not found",
|
||||
}
|
||||
)
|
||||
@decrypt_code_field
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
|
||||
@ -256,7 +256,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
@ -340,7 +339,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
@ -26,10 +26,7 @@ from core.variables.variables import VariableUnion
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.persistence import (
|
||||
PersistenceWorkflowInfo,
|
||||
WorkflowPersistenceLayer,
|
||||
)
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
@ -112,7 +109,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow=self._workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
system_variables=system_inputs,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@ -190,20 +186,20 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
self._queue_manager.graph_runtime_state = graph_runtime_state
|
||||
|
||||
if not self.application_generate_entity.is_single_stepping_container_nodes():
|
||||
persistence_layer = WorkflowPersistenceLayer(
|
||||
application_generate_entity=self.application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType(self._workflow.type),
|
||||
version=self._workflow.version,
|
||||
graph_data=self._workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
persistence_layer = WorkflowPersistenceLayer(
|
||||
application_generate_entity=self.application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType(self._workflow.type),
|
||||
version=self._workflow.version,
|
||||
graph_data=self._workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
|
||||
@ -90,7 +90,6 @@ class AppQueueManager:
|
||||
"""
|
||||
self._clear_task_belong_cache()
|
||||
self._q.put(None)
|
||||
self._graph_runtime_state = None # Release reference to allow GC to reclaim memory
|
||||
|
||||
def _clear_task_belong_cache(self) -> None:
|
||||
"""
|
||||
|
||||
@ -92,22 +92,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
|
||||
db.session.close()
|
||||
|
||||
files = self.application_generate_entity.files
|
||||
system_inputs = SystemVariable(
|
||||
files=files,
|
||||
user_id=user_id,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
document_id=self.application_generate_entity.document_id,
|
||||
original_document_id=self.application_generate_entity.original_document_id,
|
||||
batch=self.application_generate_entity.batch,
|
||||
dataset_id=self.application_generate_entity.dataset_id,
|
||||
datasource_type=self.application_generate_entity.datasource_type,
|
||||
datasource_info=self.application_generate_entity.datasource_info,
|
||||
invoke_from=self.application_generate_entity.invoke_from.value,
|
||||
)
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
@ -115,12 +99,27 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
workflow=workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
system_variables=system_inputs,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = SystemVariable(
|
||||
files=files,
|
||||
user_id=user_id,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
document_id=self.application_generate_entity.document_id,
|
||||
original_document_id=self.application_generate_entity.original_document_id,
|
||||
batch=self.application_generate_entity.batch,
|
||||
dataset_id=self.application_generate_entity.dataset_id,
|
||||
datasource_type=self.application_generate_entity.datasource_type,
|
||||
datasource_info=self.application_generate_entity.datasource_info,
|
||||
invoke_from=self.application_generate_entity.invoke_from.value,
|
||||
)
|
||||
|
||||
rag_pipeline_variables = []
|
||||
if workflow.rag_pipeline_variables:
|
||||
for v in workflow.rag_pipeline_variables:
|
||||
@ -172,21 +171,21 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
self._queue_manager.graph_runtime_state = graph_runtime_state
|
||||
if not self.application_generate_entity.is_single_stepping_container_nodes():
|
||||
persistence_layer = WorkflowPersistenceLayer(
|
||||
application_generate_entity=self.application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
persistence_layer = WorkflowPersistenceLayer(
|
||||
application_generate_entity=self.application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType(workflow.type),
|
||||
version=workflow.version,
|
||||
graph_data=workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
|
||||
@ -10,10 +10,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.persistence import (
|
||||
PersistenceWorkflowInfo,
|
||||
WorkflowPersistenceLayer,
|
||||
)
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
@ -83,7 +80,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow=self._workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
system_variables=system_inputs,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@ -136,21 +132,20 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
if not self.application_generate_entity.is_single_stepping_container_nodes():
|
||||
persistence_layer = WorkflowPersistenceLayer(
|
||||
application_generate_entity=self.application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType(self._workflow.type),
|
||||
version=self._workflow.version,
|
||||
graph_data=self._workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
persistence_layer = WorkflowPersistenceLayer(
|
||||
application_generate_entity=self.application_generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType(self._workflow.type),
|
||||
version=self._workflow.version,
|
||||
graph_data=self._workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
|
||||
@ -130,8 +130,6 @@ class WorkflowBasedAppRunner:
|
||||
workflow: Workflow,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
*,
|
||||
system_variables: SystemVariable | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
@ -149,10 +147,9 @@ class WorkflowBasedAppRunner:
|
||||
ValueError: If neither single_iteration_run nor single_loop_run is specified
|
||||
"""
|
||||
# Create initial runtime state with variable pool containing environment variables
|
||||
system_variables = system_variables or SystemVariable.empty()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=system_variables,
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
),
|
||||
@ -223,7 +220,7 @@ class WorkflowBasedAppRunner:
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
base_node_configs = [
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id
|
||||
@ -231,74 +228,26 @@ class WorkflowBasedAppRunner:
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
|
||||
# Build a base graph config (without synthetic entry) to keep node-level context minimal.
|
||||
base_graph_config = graph_config.copy()
|
||||
base_graph_config["nodes"] = base_node_configs
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in base_node_configs if isinstance(node.get("id"), str)]
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in the specified node type
|
||||
base_edge_configs = [
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
base_graph_config["edges"] = base_edge_configs
|
||||
|
||||
# Inject a synthetic start node so Graph validation accepts the single-node graph
|
||||
# (loop/iteration nodes are containers and cannot serve as graph roots).
|
||||
synthetic_start_node_id = f"{node_id}_single_step_start"
|
||||
synthetic_start_node = {
|
||||
"id": synthetic_start_node_id,
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": NodeType.START,
|
||||
"title": "Start",
|
||||
"desc": "Synthetic start for single-step run",
|
||||
"version": "1",
|
||||
"variables": [],
|
||||
},
|
||||
}
|
||||
|
||||
synthetic_end_node_id = f"{node_id}_single_step_end"
|
||||
synthetic_end_node = {
|
||||
"id": synthetic_end_node_id,
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": NodeType.END,
|
||||
"title": "End",
|
||||
"desc": "Synthetic end for single-step run",
|
||||
"version": "1",
|
||||
"outputs": [],
|
||||
},
|
||||
}
|
||||
|
||||
graph_config_with_entry = base_graph_config.copy()
|
||||
graph_config_with_entry["nodes"] = [*base_node_configs, synthetic_start_node, synthetic_end_node]
|
||||
graph_config_with_entry["edges"] = [
|
||||
*base_edge_configs,
|
||||
{
|
||||
"source": synthetic_start_node_id,
|
||||
"target": node_id,
|
||||
"sourceHandle": "source",
|
||||
"targetHandle": "target",
|
||||
},
|
||||
{
|
||||
"source": node_id,
|
||||
"target": synthetic_end_node_id,
|
||||
"sourceHandle": "source",
|
||||
"targetHandle": "target",
|
||||
},
|
||||
]
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=base_graph_config,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
@ -311,16 +260,14 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config_with_entry, node_factory=node_factory, root_node_id=synthetic_start_node_id
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
target_node_config = None
|
||||
for node in base_node_configs:
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
@ -228,9 +228,6 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
def is_single_stepping_container_nodes(self) -> bool:
|
||||
return self.single_iteration_run is not None or self.single_loop_run is not None
|
||||
|
||||
|
||||
class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
@ -261,9 +258,6 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
def is_single_stepping_container_nodes(self) -> bool:
|
||||
return self.single_iteration_run is not None or self.single_loop_run is not None
|
||||
|
||||
|
||||
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||
"""
|
||||
|
||||
@ -1,14 +1,9 @@
|
||||
from collections.abc import Mapping
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
|
||||
class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
# Use separate placeholder for base64-encoded template to avoid confusion
|
||||
_template_b64_placeholder: str = "{{template_b64}}"
|
||||
|
||||
@classmethod
|
||||
def transform_response(cls, response: str):
|
||||
"""
|
||||
@ -18,35 +13,18 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
"""
|
||||
return {"result": cls.extract_result_str_from_response(response)}
|
||||
|
||||
@classmethod
|
||||
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Override base class to use base64 encoding for template code.
|
||||
This prevents issues with special characters (quotes, newlines) in templates
|
||||
breaking the generated Python script. Fixes #26818.
|
||||
"""
|
||||
script = cls.get_runner_script()
|
||||
# Encode template as base64 to safely embed any content including quotes
|
||||
code_b64 = cls.serialize_code(code)
|
||||
script = script.replace(cls._template_b64_placeholder, code_b64)
|
||||
inputs_str = cls.serialize_inputs(inputs)
|
||||
script = script.replace(cls._inputs_placeholder, inputs_str)
|
||||
return script
|
||||
|
||||
@classmethod
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f"""
|
||||
import jinja2
|
||||
import json
|
||||
from base64 import b64decode
|
||||
|
||||
# declare main function
|
||||
def main(**inputs):
|
||||
# Decode base64-encoded template to handle special characters safely
|
||||
template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
|
||||
template = jinja2.Template(template_code)
|
||||
import jinja2
|
||||
template = jinja2.Template('''{cls._code_placeholder}''')
|
||||
return template.render(**inputs)
|
||||
|
||||
import json
|
||||
from base64 import b64decode
|
||||
|
||||
# decode and prepare input dict
|
||||
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
|
||||
|
||||
|
||||
@ -13,15 +13,6 @@ class TemplateTransformer(ABC):
|
||||
_inputs_placeholder: str = "{{inputs}}"
|
||||
_result_tag: str = "<<RESULT>>"
|
||||
|
||||
@classmethod
|
||||
def serialize_code(cls, code: str) -> str:
|
||||
"""
|
||||
Serialize template code to base64 to safely embed in generated script.
|
||||
This prevents issues with special characters like quotes breaking the script.
|
||||
"""
|
||||
code_bytes = code.encode("utf-8")
|
||||
return b64encode(code_bytes).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
|
||||
"""
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
@ -50,9 +50,7 @@ class ToolProviderListCache:
|
||||
redis_client.delete(cache_key)
|
||||
else:
|
||||
# Invalidate all caches for this tenant
|
||||
keys = ["builtin", "model", "api", "workflow", "mcp"]
|
||||
pipeline = redis_client.pipeline()
|
||||
for key in keys:
|
||||
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
|
||||
pipeline.delete(cache_key)
|
||||
pipeline.execute()
|
||||
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
|
||||
keys = list(redis_client.scan_iter(pattern))
|
||||
if keys:
|
||||
redis_client.delete(*keys)
|
||||
|
||||
@ -313,20 +313,17 @@ class StreamableHTTPTransport:
|
||||
if is_initialization:
|
||||
self._maybe_extract_session_id_from_response(response)
|
||||
|
||||
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
|
||||
# The server MUST NOT send a response to notifications.
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
|
||||
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
|
||||
|
||||
if content_type.startswith(JSON):
|
||||
self._handle_json_response(response, ctx.server_to_client_queue)
|
||||
elif content_type.startswith(SSE):
|
||||
self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.server_to_client_queue,
|
||||
)
|
||||
if content_type.startswith(JSON):
|
||||
self._handle_json_response(response, ctx.server_to_client_queue)
|
||||
elif content_type.startswith(SSE):
|
||||
self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.server_to_client_queue,
|
||||
)
|
||||
|
||||
def _handle_json_response(
|
||||
self,
|
||||
|
||||
@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
|
||||
auto_generate: PluginParameterAutoGenerate | None = None
|
||||
template: PluginParameterTemplate | None = None
|
||||
required: bool = False
|
||||
default: Union[float, int, str, bool, list, dict] | None = None
|
||||
default: Union[float, int, str, bool] | None = None
|
||||
min: Union[float, int] | None = None
|
||||
max: Union[float, int] | None = None
|
||||
precision: int | None = None
|
||||
|
||||
@ -13,7 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.embedding.retrieval import RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
@ -416,12 +416,12 @@ class RetrievalService:
|
||||
child_index_node_ids = [i for i in child_index_node_ids if i]
|
||||
index_node_ids = [i for i in index_node_ids if i]
|
||||
|
||||
segment_ids: list[str] = []
|
||||
segment_ids = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||
doc_segment_map: dict[str, list[str]] = {}
|
||||
attachment_map = {}
|
||||
child_chunk_map: dict[Any, Any] = {}
|
||||
doc_segment_map = {}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
|
||||
@ -432,7 +432,7 @@ class RetrievalService:
|
||||
attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
|
||||
else:
|
||||
attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
|
||||
if attachment["segment_id"] in doc_segment_map:
|
||||
if attachment["attachment_id"] in doc_segment_map:
|
||||
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
|
||||
else:
|
||||
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
|
||||
@ -502,7 +502,7 @@ class RetrievalService:
|
||||
"child_chunks": child_chunk_details,
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record: dict[str, Any] = {
|
||||
record = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
@ -510,13 +510,13 @@ class RetrievalService:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
max_score = 0.0
|
||||
segment_document = doc_to_document_map.get(segment.index_node_id)
|
||||
if segment_document:
|
||||
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
|
||||
document = doc_to_document_map.get(segment.index_node_id)
|
||||
if document:
|
||||
max_score = max(max_score, document.metadata.get("score", 0.0))
|
||||
for attachment_info in attachment_infos:
|
||||
file_doc = doc_to_document_map.get(attachment_info["id"])
|
||||
if file_doc:
|
||||
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
|
||||
file_document = doc_to_document_map.get(attachment_info["id"])
|
||||
if file_document:
|
||||
max_score = max(max_score, file_document.metadata.get("score", 0.0))
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": max_score,
|
||||
@ -531,26 +531,18 @@ class RetrievalService:
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
|
||||
result: list[RetrievalSegments] = []
|
||||
result = []
|
||||
for record in records:
|
||||
# Extract segment
|
||||
segment = record["segment"]
|
||||
|
||||
# Extract child_chunks, ensuring it's a list or None
|
||||
raw_child_chunks = record.get("child_chunks")
|
||||
child_chunks_list: list[RetrievalChildChunk] | None = None
|
||||
if isinstance(raw_child_chunks, list):
|
||||
# Sort by score descending
|
||||
sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
|
||||
child_chunks_list = [
|
||||
RetrievalChildChunk(
|
||||
id=chunk["id"],
|
||||
content=chunk["content"],
|
||||
score=chunk.get("score", 0.0),
|
||||
position=chunk["position"],
|
||||
)
|
||||
for chunk in sorted_chunks
|
||||
]
|
||||
child_chunks = record.get("child_chunks")
|
||||
if not isinstance(child_chunks, list):
|
||||
child_chunks = None
|
||||
|
||||
if child_chunks:
|
||||
child_chunks = sorted(child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
|
||||
|
||||
# Extract files, ensuring it's a list or None
|
||||
files = record.get("files")
|
||||
@ -567,11 +559,11 @@ class RetrievalService:
|
||||
|
||||
# Create RetrievalSegments object
|
||||
retrieval_segment = RetrievalSegments(
|
||||
segment=segment, child_chunks=child_chunks_list, score=score, files=files
|
||||
segment=segment, child_chunks=child_chunks, score=score, files=files
|
||||
)
|
||||
result.append(retrieval_segment)
|
||||
|
||||
return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
|
||||
return sorted(result, key=lambda x: x.score, reverse=True)
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
raise e
|
||||
|
||||
@ -255,10 +255,7 @@ class PGVector(BaseVector):
|
||||
return
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
|
||||
if not cur.fetchone():
|
||||
cur.execute("CREATE EXTENSION vector")
|
||||
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||
# PG hnsw index only support 2000 dimension or less
|
||||
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
|
||||
|
||||
@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import and_, literal, or_, select
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
@ -1036,7 +1036,7 @@ class DatasetRetrieval:
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
self.process_metadata_filter_func(
|
||||
self._process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition"), # type: ignore
|
||||
filter.get("metadata_name"), # type: ignore
|
||||
@ -1072,7 +1072,7 @@ class DatasetRetrieval:
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self.process_metadata_filter_func(
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
@ -1168,9 +1168,8 @@ class DatasetRetrieval:
|
||||
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
@classmethod
|
||||
def process_metadata_filter_func(
|
||||
cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
):
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
@ -1219,20 +1218,6 @@ class DatasetRetrieval:
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
|
||||
case "in" | "not in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
# `field in []` is False, `field not in []` is True
|
||||
filters.append(literal(condition == "not in"))
|
||||
else:
|
||||
op = json_field.in_ if condition == "in" else json_field.notin_
|
||||
filters.append(op(value_list))
|
||||
case _:
|
||||
pass
|
||||
|
||||
|
||||
@ -6,15 +6,7 @@ from typing import Any
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import (
|
||||
AudioContent,
|
||||
BlobResourceContents,
|
||||
CallToolResult,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
)
|
||||
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
@ -61,19 +53,10 @@ class MCPTool(Tool):
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
yield from self._process_text_content(content)
|
||||
elif isinstance(content, ImageContent | AudioContent):
|
||||
yield self.create_blob_message(
|
||||
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
|
||||
)
|
||||
elif isinstance(content, EmbeddedResource):
|
||||
resource = content.resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
yield self.create_text_message(resource.text)
|
||||
elif isinstance(resource, BlobResourceContents):
|
||||
mime_type = resource.mimeType or "application/octet-stream"
|
||||
yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
|
||||
else:
|
||||
raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self._process_image_content(content)
|
||||
elif isinstance(content, AudioContent):
|
||||
yield self._process_audio_content(content)
|
||||
else:
|
||||
logger.warning("Unsupported content type=%s", type(content))
|
||||
|
||||
@ -118,6 +101,14 @@ class MCPTool(Tool):
|
||||
for item in json_list:
|
||||
yield self.create_json_message(item)
|
||||
|
||||
def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
|
||||
"""Process image content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
|
||||
"""Process audio content and return a blob message."""
|
||||
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
|
||||
@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@ -48,30 +47,33 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
app = session.get(App, db_provider.app_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
|
||||
if not provider:
|
||||
raise ValueError("workflow provider not found")
|
||||
app = session.get(App, provider.app_id)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
|
||||
user = session.get(Account, provider.user_id) if provider.user_id else None
|
||||
|
||||
controller = WorkflowToolProviderController(
|
||||
entity=ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user.name if user else "",
|
||||
name=db_provider.label,
|
||||
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
|
||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
icon=db_provider.icon,
|
||||
name=provider.label,
|
||||
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
|
||||
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
|
||||
icon=provider.icon,
|
||||
),
|
||||
credentials_schema=[],
|
||||
plugin_id=None,
|
||||
),
|
||||
provider_id="",
|
||||
provider_id=provider.id or "",
|
||||
)
|
||||
|
||||
controller.tools = [
|
||||
controller._get_db_provider_tool(db_provider, app, session=session, user=user),
|
||||
controller._get_db_provider_tool(provider, app, session=session, user=user),
|
||||
]
|
||||
|
||||
return controller
|
||||
|
||||
@ -149,49 +149,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
return isinstance(variable, NoneSegment) or len(variable.value) == 0
|
||||
|
||||
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
|
||||
started_at = naive_utc_now()
|
||||
inputs: dict[str, object] = {"iterator_selector": []}
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
yield IterationStartedEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
metadata={"iteration_length": 0},
|
||||
)
|
||||
|
||||
# Try our best to preserve the type information.
|
||||
if isinstance(variable, ArraySegment):
|
||||
output = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
output = ArrayAnySegment(value=[])
|
||||
|
||||
yield IterationSucceededEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": []},
|
||||
steps=0,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: {},
|
||||
},
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
||||
# from graph definition?
|
||||
outputs={"output": output},
|
||||
inputs=inputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: {},
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
DatasetRetrieval.process_metadata_filter_func(
|
||||
self._process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = DatasetRetrieval.process_metadata_filter_func(
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
@ -603,6 +603,87 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
return [], usage
|
||||
return automatic_metadata_filters, usage
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||
) -> list[Any]:
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
||||
json_field = Document.doc_metadata[metadata_name].as_string()
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
|
||||
case "start with":
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
|
||||
case "end with":
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
case "in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(False))
|
||||
else:
|
||||
filters.append(json_field.in_(value_list))
|
||||
|
||||
case "not in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(True))
|
||||
else:
|
||||
filters.append(json_field.notin_(value_list))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field == value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
|
||||
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field != value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
|
||||
|
||||
case "empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
||||
|
||||
case "not empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
||||
|
||||
case "before" | "<":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
|
||||
|
||||
case "after" | ">":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
|
||||
|
||||
case "≤" | "<=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
return filters
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re
|
||||
from .delete_tool_parameters_cache_when_sync_draft_workflow import (
|
||||
handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow,
|
||||
)
|
||||
from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created
|
||||
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
|
||||
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
|
||||
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
|
||||
@ -30,6 +31,7 @@ __all__ = [
|
||||
"handle_create_installed_app_when_app_created",
|
||||
"handle_create_site_record_when_app_created",
|
||||
"handle_delete_tool_parameters_cache_when_sync_draft_workflow",
|
||||
"handle_queue_credential_sync_when_tenant_created",
|
||||
"handle_sync_plugin_trigger_when_app_created",
|
||||
"handle_sync_webhook_when_app_created",
|
||||
"handle_sync_workflow_schedule_when_app_published",
|
||||
|
||||
@ -0,0 +1,19 @@
|
||||
from configs import dify_config
|
||||
from events.tenant_event import tenant_was_created
|
||||
from services.enterprise.workspace_sync import WorkspaceSyncService
|
||||
|
||||
|
||||
@tenant_was_created.connect
|
||||
def handle(sender, **kwargs):
|
||||
"""Queue credential sync when a tenant/workspace is created."""
|
||||
# Only queue sync tasks if plugin manager (enterprise feature) is enabled
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
|
||||
tenant = sender
|
||||
|
||||
# Determine source from kwargs if available, otherwise use generic
|
||||
source = kwargs.get("source", "tenant_created")
|
||||
|
||||
# Queue credential sync task to Redis for enterprise backend to process
|
||||
WorkspaceSyncService.queue_credential_sync(tenant.id, source=source)
|
||||
@ -14,8 +14,7 @@ from enums.quota_type import QuotaType, unlimited
|
||||
from extensions.otel import AppGenerateHandler, trace_span
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from models.model import App, EndUser
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
from models.workflow import Workflow
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
|
||||
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
|
||||
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
|
||||
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -141,7 +141,7 @@ class AsyncWorkflowService:
|
||||
trigger_log_repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
raise WorkflowQuotaLimitError(
|
||||
raise InvokeRateLimitError(
|
||||
f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
|
||||
) from e
|
||||
|
||||
|
||||
58
api/services/enterprise/workspace_sync.py
Normal file
58
api/services/enterprise/workspace_sync.py
Normal file
@ -0,0 +1,58 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from redis import RedisError
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
|
||||
WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"
|
||||
|
||||
|
||||
class WorkspaceSyncService:
|
||||
"""Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
|
||||
|
||||
@staticmethod
|
||||
def queue_credential_sync(workspace_id: str, *, source: str) -> bool:
|
||||
"""
|
||||
Queue a credential sync task for a newly created workspace.
|
||||
|
||||
This publishes a task to Redis that will be consumed by the enterprise backend
|
||||
worker to sync credentials with the plugin-manager.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace/tenant ID to sync credentials for
|
||||
source: Source of the sync request (for debugging/tracking)
|
||||
|
||||
Returns:
|
||||
bool: True if task was queued successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
task = {
|
||||
"task_id": str(uuid.uuid4()),
|
||||
"workspace_id": workspace_id,
|
||||
"retry_count": 0,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"source": source,
|
||||
}
|
||||
|
||||
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
|
||||
redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task))
|
||||
|
||||
logger.info(
|
||||
"Queued credential sync task for workspace %s, task_id: %s, source: %s",
|
||||
workspace_id,
|
||||
task["task_id"],
|
||||
source,
|
||||
)
|
||||
return True
|
||||
|
||||
except (RedisError, TypeError) as e:
|
||||
logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True)
|
||||
# Don't raise - we don't want to fail workspace creation if queueing fails
|
||||
# The scheduled task will catch it later
|
||||
return False
|
||||
@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowQuotaLimitError(Exception):
|
||||
"""Raised when workflow execution quota is exceeded (for async/background workflows)."""
|
||||
class InvokeRateLimitError(Exception):
|
||||
"""Raised when rate limit is exceeded for workflow invocations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -146,7 +146,7 @@ class PluginParameterService:
|
||||
provider,
|
||||
action,
|
||||
resolved_credentials,
|
||||
original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
|
||||
CredentialType.API_KEY.value,
|
||||
parameter,
|
||||
)
|
||||
.options
|
||||
|
||||
@ -319,14 +319,8 @@ class MCPToolManageService:
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
# Update database with retrieved tools (ensure description is a non-null string)
|
||||
tools_payload = []
|
||||
for tool in tools:
|
||||
data = tool.model_dump()
|
||||
if data.get("description") is None:
|
||||
data["description"] = ""
|
||||
tools_payload.append(data)
|
||||
db_provider.tools = json.dumps(tools_payload)
|
||||
# Update database with retrieved tools
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.flush()
|
||||
@ -626,21 +620,6 @@ class MCPToolManageService:
|
||||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def reconnect_with_url(
|
||||
*,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
timeout: float | None,
|
||||
sse_read_timeout: float | None,
|
||||
) -> ReconnectResult:
|
||||
return MCPToolManageService._reconnect_with_url(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reconnect_with_url(
|
||||
*,
|
||||
@ -663,16 +642,9 @@ class MCPToolManageService:
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
# Ensure tool descriptions are non-null in payload
|
||||
tools_payload = []
|
||||
for t in tools:
|
||||
d = t.model_dump()
|
||||
if d.get("description") is None:
|
||||
d["description"] = ""
|
||||
tools_payload.append(d)
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps(tools_payload),
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
|
||||
@ -5,8 +5,8 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@ -68,27 +68,26 @@ class WorkflowToolManageService:
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
||||
|
||||
workflow_tool_provider = WorkflowToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=workflow_app_id,
|
||||
name=name,
|
||||
label=label,
|
||||
icon=json.dumps(icon),
|
||||
description=description,
|
||||
parameter_configuration=json.dumps(parameters),
|
||||
privacy_policy=privacy_policy,
|
||||
version=workflow.version,
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
workflow_tool_provider = WorkflowToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=workflow_app_id,
|
||||
name=name,
|
||||
label=label,
|
||||
icon=json.dumps(icon),
|
||||
description=description,
|
||||
parameter_configuration=json.dumps(parameters),
|
||||
privacy_policy=privacy_policy,
|
||||
version=workflow.version,
|
||||
)
|
||||
session.add(workflow_tool_provider)
|
||||
|
||||
try:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
session.add(workflow_tool_provider)
|
||||
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
|
||||
@ -868,111 +868,48 @@ class TriggerProviderService:
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
# Use distributed lock to prevent race conditions on the same subscription
|
||||
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
# Get subscription within the transaction
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
|
||||
credential_type = CredentialType.of(subscription.credential_type)
|
||||
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
|
||||
raise ValueError("Credential type not supported for rebuild")
|
||||
credential_type = CredentialType.of(subscription.credential_type)
|
||||
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
|
||||
raise ValueError("Credential type not supported for rebuild")
|
||||
|
||||
# Decrypt existing credentials for merging
|
||||
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
|
||||
# TODO: Trying to invoke update api of the plugin trigger provider
|
||||
|
||||
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
|
||||
merged_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
# FALLBACK: If the update api is not implemented, delete the previous subscription and create a new one
|
||||
|
||||
user_id = subscription.user_id
|
||||
# Delete the previous subscription
|
||||
user_id = subscription.user_id
|
||||
TriggerManager.unsubscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
subscription=subscription.to_entity(),
|
||||
credentials=subscription.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
|
||||
# TODO: Trying to invoke update api of the plugin trigger provider
|
||||
|
||||
# FALLBACK: If the update api is not implemented,
|
||||
# delete the previous subscription and create a new one
|
||||
|
||||
# Unsubscribe the previous subscription (external call, but we'll handle errors)
|
||||
try:
|
||||
TriggerManager.unsubscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
subscription=subscription.to_entity(),
|
||||
credentials=decrypted_credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
|
||||
# Continue anyway - the subscription might already be deleted externally
|
||||
|
||||
# Create a new subscription with the same subscription_id and endpoint_id (external call)
|
||||
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||
parameters=parameters,
|
||||
credentials=merged_credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
|
||||
# Update the subscription in the same transaction
|
||||
# Inline update logic to reuse the same session
|
||||
if name is not None and name != subscription.name:
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||
.first()
|
||||
)
|
||||
if existing and existing.id != subscription.id:
|
||||
raise ValueError(f"Subscription name '{name}' already exists for this provider")
|
||||
subscription.name = name
|
||||
|
||||
# Update parameters
|
||||
subscription.parameters = dict(parameters)
|
||||
|
||||
# Update credentials with merged (and encrypted) values
|
||||
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
|
||||
|
||||
# Update properties
|
||||
if new_subscription.properties:
|
||||
properties_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_properties_schema(),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
|
||||
|
||||
# Update expiration timestamp
|
||||
if new_subscription.expires_at is not None:
|
||||
subscription.expires_at = new_subscription.expires_at
|
||||
|
||||
# Commit the transaction
|
||||
session.commit()
|
||||
|
||||
# Clear subscription cache
|
||||
delete_cache_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=subscription.provider_id,
|
||||
subscription_id=subscription.id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on any error
|
||||
session.rollback()
|
||||
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
|
||||
raise
|
||||
# Create a new subscription with the same subscription_id and endpoint_id
|
||||
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||
parameters=parameters,
|
||||
credentials=credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription.id,
|
||||
name=name,
|
||||
parameters=parameters,
|
||||
credentials=credentials,
|
||||
properties=new_subscription.properties,
|
||||
expires_at=new_subscription.expires_at,
|
||||
)
|
||||
|
||||
@ -863,18 +863,10 @@ class WebhookService:
|
||||
not_found_in_cache.append(node_id)
|
||||
continue
|
||||
|
||||
lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock"
|
||||
lock = redis_client.lock(lock_key, timeout=10)
|
||||
lock_acquired = False
|
||||
|
||||
try:
|
||||
# acquire the lock with blocking and timeout
|
||||
lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
|
||||
if not lock_acquired:
|
||||
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
|
||||
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
# lock the concurrent webhook trigger creation
|
||||
redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
|
||||
# fetch the non-cached nodes from DB
|
||||
all_records = session.scalars(
|
||||
select(WorkflowWebhookTrigger).where(
|
||||
@ -911,16 +903,11 @@ class WebhookService:
|
||||
session.delete(nodes_id_in_db[node_id])
|
||||
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to sync webhook relationships for app %s", app.id)
|
||||
raise
|
||||
finally:
|
||||
# release the lock only if it was acquired
|
||||
if lock_acquired:
|
||||
try:
|
||||
lock.release()
|
||||
except Exception:
|
||||
logger.exception("Failed to release lock for webhook sync, app %s", app.id)
|
||||
except Exception:
|
||||
logger.exception("Failed to sync webhook relationships for app %s", app.id)
|
||||
raise
|
||||
finally:
|
||||
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
|
||||
|
||||
@classmethod
|
||||
def generate_webhook_id(cls) -> str:
|
||||
|
||||
@ -7,14 +7,11 @@ CODE_LANGUAGE = CodeLanguage.JINJA2
|
||||
|
||||
|
||||
def test_jinja2():
|
||||
"""Test basic Jinja2 template rendering."""
|
||||
template = "Hello {{template}}"
|
||||
# Template must be base64 encoded to match the new safe embedding approach
|
||||
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
|
||||
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
||||
code = (
|
||||
Jinja2TemplateTransformer.get_runner_script()
|
||||
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
|
||||
.replace(Jinja2TemplateTransformer._code_placeholder, template)
|
||||
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
||||
)
|
||||
result = CodeExecutor.execute_code(
|
||||
@ -24,7 +21,6 @@ def test_jinja2():
|
||||
|
||||
|
||||
def test_jinja2_with_code_template():
|
||||
"""Test template rendering via the high-level workflow API."""
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
|
||||
)
|
||||
@ -32,64 +28,7 @@ def test_jinja2_with_code_template():
|
||||
|
||||
|
||||
def test_jinja2_get_runner_script():
|
||||
"""Test that runner script contains required placeholders."""
|
||||
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
||||
|
||||
|
||||
def test_jinja2_template_with_special_characters():
|
||||
"""
|
||||
Test that templates with special characters (quotes, newlines) render correctly.
|
||||
This is a regression test for issue #26818 where textarea pre-fill values
|
||||
containing special characters would break template rendering.
|
||||
"""
|
||||
# Template with triple quotes, single quotes, double quotes, and newlines
|
||||
template = """<html>
|
||||
<body>
|
||||
<input value="{{ task.get('Task ID', '') }}"/>
|
||||
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
|
||||
<p>Status: "{{ status }}"</p>
|
||||
<pre>'''code block'''</pre>
|
||||
</body>
|
||||
</html>"""
|
||||
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
|
||||
|
||||
# Verify the template rendered correctly with all special characters
|
||||
output = result["result"]
|
||||
assert 'value="TASK-123"' in output
|
||||
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
|
||||
assert 'Status: "completed"' in output
|
||||
assert "'''code block'''" in output
|
||||
|
||||
|
||||
def test_jinja2_template_with_html_textarea_prefill():
|
||||
"""
|
||||
Specific test for HTML textarea with Jinja2 variable pre-fill.
|
||||
Verifies fix for issue #26818.
|
||||
"""
|
||||
template = "<textarea name='notes'>{{ notes }}</textarea>"
|
||||
notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes."
|
||||
inputs = {"notes": notes_content}
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)
|
||||
|
||||
expected_output = f"<textarea name='notes'>{notes_content}</textarea>"
|
||||
assert result["result"] == expected_output
|
||||
|
||||
|
||||
def test_jinja2_assemble_runner_script_encodes_template():
|
||||
"""Test that assemble_runner_script properly base64 encodes the template."""
|
||||
template = "Hello {{ name }}!"
|
||||
inputs = {"name": "World"}
|
||||
|
||||
script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs)
|
||||
|
||||
# The template should be base64 encoded in the script
|
||||
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
|
||||
assert template_b64 in script
|
||||
# The raw template should NOT appear in the script (it's encoded)
|
||||
assert "Hello {{ name }}!" not in script
|
||||
|
||||
@ -1,682 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
|
||||
class TestTriggerProviderService:
|
||||
"""Integration tests for TriggerProviderService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager,
|
||||
patch("services.trigger.trigger_provider_service.redis_client") as mock_redis_client,
|
||||
patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") as mock_delete_cache,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_provider_controller = MagicMock()
|
||||
mock_provider_controller.get_credential_schema_config.return_value = MagicMock()
|
||||
mock_provider_controller.get_properties_schema.return_value = MagicMock()
|
||||
mock_trigger_manager.get_trigger_provider.return_value = mock_provider_controller
|
||||
|
||||
# Mock redis lock
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock(return_value=None)
|
||||
mock_lock.__exit__ = MagicMock(return_value=None)
|
||||
mock_redis_client.lock.return_value = mock_lock
|
||||
|
||||
# Setup account feature service mock
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
yield {
|
||||
"trigger_manager": mock_trigger_manager,
|
||||
"redis_client": mock_redis_client,
|
||||
"delete_cache": mock_delete_cache,
|
||||
"provider_controller": mock_provider_controller,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies[
|
||||
"trigger_manager"
|
||||
].get_trigger_provider.return_value = mock_external_service_dependencies["provider_controller"]
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_subscription(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
tenant_id,
|
||||
user_id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
credentials,
|
||||
mock_external_service_dependencies,
|
||||
):
|
||||
"""
|
||||
Helper method to create a test trigger subscription.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session
|
||||
tenant_id: Tenant ID
|
||||
user_id: User ID
|
||||
provider_id: Provider ID
|
||||
credential_type: Credential type
|
||||
credentials: Credentials dict
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
TriggerSubscription: Created subscription instance
|
||||
"""
|
||||
fake = Faker()
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import create_provider_encrypter
|
||||
|
||||
# Use mock provider controller to encrypt credentials
|
||||
provider_controller = mock_external_service_dependencies["provider_controller"]
|
||||
|
||||
# Create encrypter for credentials
|
||||
credential_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credential_schema_config(credential_type),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
subscription = TriggerSubscription(
|
||||
name=fake.word(),
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=str(provider_id),
|
||||
endpoint_id=fake.uuid4(),
|
||||
parameters={"param1": "value1"},
|
||||
properties={"prop1": "value1"},
|
||||
credentials=dict(credential_encrypter.encrypt(credentials)),
|
||||
credential_type=credential_type.value,
|
||||
credential_expires_at=-1,
|
||||
expires_at=-1,
|
||||
)
|
||||
|
||||
db.session.add(subscription)
|
||||
db.session.commit()
|
||||
db.session.refresh(subscription)
|
||||
|
||||
return subscription
|
||||
|
||||
def test_rebuild_trigger_subscription_success_with_merged_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful rebuild with credential merging (HIDDEN_VALUE handling).
|
||||
|
||||
This test verifies:
|
||||
- Credentials are properly merged (HIDDEN_VALUE replaced with existing values)
|
||||
- Single transaction wraps all operations
|
||||
- Merged credentials are used for subscribe and update
|
||||
- Database state is correctly updated
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Create initial subscription with credentials
|
||||
original_credentials = {"api_key": "original-secret-key", "api_secret": "original-secret"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Prepare new credentials with HIDDEN_VALUE for api_key (should keep original)
|
||||
# and new value for api_secret (should update)
|
||||
new_credentials = {
|
||||
"api_key": HIDDEN_VALUE, # Should be replaced with original
|
||||
"api_secret": "new-secret-value", # Should be updated
|
||||
}
|
||||
|
||||
# Mock subscribe_trigger to return a new subscription entity
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={"param1": "value1"},
|
||||
properties={"prop1": "new_prop_value"},
|
||||
expires_at=1234567890,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
|
||||
# Mock unsubscribe_trigger
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={"param1": "updated_value"},
|
||||
name="updated_name",
|
||||
)
|
||||
|
||||
# Verify unsubscribe was called with decrypted original credentials
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.assert_called_once()
|
||||
unsubscribe_call_args = mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.call_args
|
||||
assert unsubscribe_call_args.kwargs["tenant_id"] == tenant.id
|
||||
assert unsubscribe_call_args.kwargs["provider_id"] == provider_id
|
||||
assert unsubscribe_call_args.kwargs["credential_type"] == credential_type
|
||||
|
||||
# Verify subscribe was called with merged credentials (api_key from original, api_secret new)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == original_credentials["api_key"] # Merged from original
|
||||
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
|
||||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.name == "updated_name"
|
||||
assert subscription.parameters == {"param1": "updated_value"}
|
||||
|
||||
# Verify credentials in DB were updated with merged values (decrypt to check)
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import create_provider_encrypter
|
||||
|
||||
# Use mock provider controller to decrypt credentials
|
||||
provider_controller = mock_external_service_dependencies["provider_controller"]
|
||||
credential_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant.id,
|
||||
config=provider_controller.get_credential_schema_config(credential_type),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
decrypted_db_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
|
||||
assert decrypted_db_credentials["api_key"] == original_credentials["api_key"]
|
||||
assert decrypted_db_credentials["api_secret"] == "new-secret-value"
|
||||
|
||||
# Verify cache was cleared
|
||||
mock_external_service_dependencies["delete_cache"].assert_called_once_with(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=subscription.provider_id,
|
||||
subscription_id=subscription.id,
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_with_all_new_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when all credentials are new (no HIDDEN_VALUE).
|
||||
|
||||
This test verifies:
|
||||
- All new credentials are used when no HIDDEN_VALUE is present
|
||||
- Merged credentials contain only new values
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Create initial subscription
|
||||
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# All new credentials (no HIDDEN_VALUE)
|
||||
new_credentials = {
|
||||
"api_key": "completely-new-key",
|
||||
"api_secret": "completely-new-secret",
|
||||
}
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was called with all new credentials
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == "completely-new-key"
|
||||
assert subscribe_credentials["api_secret"] == "completely-new-secret"
|
||||
|
||||
def test_rebuild_trigger_subscription_with_all_hidden_values(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
|
||||
|
||||
This test verifies:
|
||||
- All HIDDEN_VALUE credentials are replaced with existing values
|
||||
- Original credentials are preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key", "api_secret": "original-secret"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# All HIDDEN_VALUE (should preserve all original)
|
||||
new_credentials = {
|
||||
"api_key": HIDDEN_VALUE,
|
||||
"api_secret": HIDDEN_VALUE,
|
||||
}
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was called with all original credentials
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
|
||||
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
|
||||
|
||||
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
|
||||
|
||||
This test verifies:
|
||||
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Original has only api_key
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# HIDDEN_VALUE for non-existent key should use UNKNOWN_VALUE
|
||||
new_credentials = {
|
||||
"api_key": HIDDEN_VALUE,
|
||||
"non_existent_key": HIDDEN_VALUE, # This key doesn't exist in original
|
||||
}
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials=new_credentials,
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was called with original api_key and UNKNOWN_VALUE for missing key
|
||||
subscribe_call_args = mock_external_service_dependencies["trigger_manager"].subscribe_trigger.call_args
|
||||
subscribe_credentials = subscribe_call_args.kwargs["credentials"]
|
||||
assert subscribe_credentials["api_key"] == original_credentials["api_key"]
|
||||
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
|
||||
|
||||
def test_rebuild_trigger_subscription_rollback_on_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that transaction is rolled back on error.
|
||||
|
||||
This test verifies:
|
||||
- Database transaction is rolled back when an error occurs
|
||||
- Original subscription state is preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
original_name = subscription.name
|
||||
original_parameters = subscription.parameters.copy()
|
||||
|
||||
# Make subscribe_trigger raise an error
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.side_effect = ValueError(
|
||||
"Subscribe failed"
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Execute rebuild and expect error
|
||||
with pytest.raises(ValueError, match="Subscribe failed"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscription state was not changed (rolled back)
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.name == original_name
|
||||
assert subscription.parameters == original_parameters
|
||||
|
||||
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that unsubscribe errors are handled gracefully and operation continues.
|
||||
|
||||
This test verifies:
|
||||
- Unsubscribe errors are caught and logged but don't stop the rebuild
|
||||
- Rebuild continues even if unsubscribe fails
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Make unsubscribe_trigger raise an error (should be caught and continue)
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
|
||||
"Unsubscribe failed"
|
||||
)
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
|
||||
# Execute rebuild - should succeed despite unsubscribe error
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was still called (operation continued)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||
|
||||
# Verify subscription was updated
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.parameters == {}
|
||||
|
||||
def test_rebuild_trigger_subscription_subscription_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when subscription is not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised when subscription doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
fake_subscription_id = fake.uuid4()
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=fake_subscription_id,
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when provider is not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised when provider doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
|
||||
|
||||
# Make get_trigger_provider return None
|
||||
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Provider.*not found"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=fake.uuid4(),
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_unsupported_credential_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when credential type is not supported for rebuild.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.UNAUTHORIZED # Not supported
|
||||
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_name_uniqueness_check(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that name uniqueness is checked when updating name.
|
||||
|
||||
This test verifies:
|
||||
- Error is raised when new name conflicts with existing subscription
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
# Create first subscription
|
||||
subscription1 = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{"api_key": "key1"},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Create second subscription with different name
|
||||
subscription2 = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{"api_key": "key2"},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription2.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.return_value = MagicMock()
|
||||
|
||||
# Try to rename subscription2 to subscription1's name (should fail)
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription2.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
name=subscription1.name, # Conflicting name
|
||||
)
|
||||
@ -705,207 +705,3 @@ class TestWorkflowToolManageService:
|
||||
db.session.refresh(created_tool)
|
||||
assert created_tool.name == first_tool_name
|
||||
assert created_tool.updated_at is not None
|
||||
|
||||
def test_create_workflow_tool_with_file_parameter_default(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation with FILE parameter having a file object as default.
|
||||
|
||||
This test verifies:
|
||||
- FILE parameters can have file object defaults
|
||||
- The default value (dict with id/base64Url) is properly handled
|
||||
- Tool creation succeeds without Pydantic validation errors
|
||||
|
||||
Related issue: Array[File] default value causes Pydantic validation errors.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create workflow graph with a FILE variable that has a default value
|
||||
workflow_graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start_node",
|
||||
"data": {
|
||||
"type": "start",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "document",
|
||||
"label": "Document",
|
||||
"type": "file",
|
||||
"required": False,
|
||||
"default": {"id": fake.uuid4(), "base64Url": ""},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
workflow.graph = json.dumps(workflow_graph)
|
||||
|
||||
# Setup workflow tool parameters with FILE type
|
||||
file_parameters = [
|
||||
{
|
||||
"name": "document",
|
||||
"description": "Upload a document",
|
||||
"form": "form",
|
||||
"type": "file",
|
||||
"required": False,
|
||||
}
|
||||
]
|
||||
|
||||
# Execute the method under test
|
||||
# Note: from_db is mocked, so this test primarily validates the parameter configuration
|
||||
result = WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "📄"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=file_parameters,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_workflow_tool_with_files_parameter_default(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
|
||||
|
||||
This test verifies:
|
||||
- FILES parameters can have a list of file objects as default
|
||||
- The default value (list of dicts with id/base64Url) is properly handled
|
||||
- Tool creation succeeds without Pydantic validation errors
|
||||
|
||||
Related issue: Array[File] default value causes 4 Pydantic validation errors
|
||||
because PluginParameter.default only accepts Union[float, int, str, bool] | None.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create workflow graph with a FILE_LIST variable that has a default value
|
||||
workflow_graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start_node",
|
||||
"data": {
|
||||
"type": "start",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "documents",
|
||||
"label": "Documents",
|
||||
"type": "file-list",
|
||||
"required": False,
|
||||
"default": [
|
||||
{"id": fake.uuid4(), "base64Url": ""},
|
||||
{"id": fake.uuid4(), "base64Url": ""},
|
||||
],
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
workflow.graph = json.dumps(workflow_graph)
|
||||
|
||||
# Setup workflow tool parameters with FILES type
|
||||
files_parameters = [
|
||||
{
|
||||
"name": "documents",
|
||||
"description": "Upload multiple documents",
|
||||
"form": "form",
|
||||
"type": "files",
|
||||
"required": False,
|
||||
}
|
||||
]
|
||||
|
||||
# Execute the method under test
|
||||
# Note: from_db is mocked, so this test primarily validates the parameter configuration
|
||||
result = WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "📁"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=files_parameters,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_workflow_tool_db_commit_before_validation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that database commit happens before validation, causing DB pollution on validation failure.
|
||||
|
||||
This test verifies the second bug:
|
||||
- WorkflowToolProvider is committed to database BEFORE from_db validation
|
||||
- If validation fails, the record remains in the database
|
||||
- Subsequent attempts fail with "Tool already exists" error
|
||||
|
||||
This demonstrates why we need to validate BEFORE database commit.
|
||||
"""
|
||||
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
tool_name = fake.word()
|
||||
|
||||
# Mock from_db to raise validation error
|
||||
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.side_effect = ValueError(
|
||||
"Validation failed: default parameter type mismatch"
|
||||
)
|
||||
|
||||
# Attempt to create workflow tool (will fail at validation stage)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=tool_name,
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=self._create_test_workflow_tool_parameters(),
|
||||
)
|
||||
|
||||
assert "Validation failed" in str(exc_info.value)
|
||||
|
||||
# Verify the tool was NOT created in database
|
||||
# This is the expected behavior (no pollution)
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.name == tool_name,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# The record should NOT exist because the transaction should be rolled back
|
||||
# Currently, due to the bug, the record might exist (this test documents the bug)
|
||||
# After the fix, this should always be 0
|
||||
# For now, we document that the record may exist, demonstrating the bug
|
||||
# assert tool_count == 0 # Expected after fix
|
||||
|
||||
@ -12,12 +12,10 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
|
||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||
|
||||
template = "Hello {{template}}"
|
||||
# Template must be base64 encoded to match the new safe embedding approach
|
||||
template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
|
||||
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
||||
code = (
|
||||
Jinja2TemplateTransformer.get_runner_script()
|
||||
.replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
|
||||
.replace(Jinja2TemplateTransformer._code_placeholder, template)
|
||||
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
||||
)
|
||||
result = CodeExecutor.execute_code(
|
||||
@ -39,34 +37,6 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
|
||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||
|
||||
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
||||
|
||||
def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
|
||||
"""
|
||||
Test that templates with special characters (quotes, newlines) render correctly.
|
||||
This is a regression test for issue #26818 where textarea pre-fill values
|
||||
containing special characters would break template rendering.
|
||||
"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
# Template with triple quotes, single quotes, double quotes, and newlines
|
||||
template = """<html>
|
||||
<body>
|
||||
<input value="{{ task.get('Task ID', '') }}"/>
|
||||
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
|
||||
<p>Status: "{{ status }}"</p>
|
||||
<pre>'''code block'''</pre>
|
||||
</body>
|
||||
</html>"""
|
||||
inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
|
||||
|
||||
# Verify the template rendered correctly with all special characters
|
||||
output = result["result"]
|
||||
assert 'value="TASK-123"' in output
|
||||
assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
|
||||
assert 'Status: "completed"' in output
|
||||
assert "'''code block'''" in output
|
||||
|
||||
@ -1,145 +0,0 @@
|
||||
"""Unit tests for load balancing credential validation APIs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Reload controller module with lightweight decorators for testing."""
|
||||
|
||||
from controllers.console import console_ns, wraps
|
||||
from libs import login
|
||||
|
||||
def _noop(func):
|
||||
return func
|
||||
|
||||
monkeypatch.setattr(login, "login_required", _noop)
|
||||
monkeypatch.setattr(wraps, "setup_required", _noop)
|
||||
monkeypatch.setattr(wraps, "account_initialization_required", _noop)
|
||||
|
||||
def _noop_route(*args, **kwargs): # type: ignore[override]
|
||||
def _decorator(cls):
|
||||
return cls
|
||||
|
||||
return _decorator
|
||||
|
||||
monkeypatch.setattr(console_ns, "route", _noop_route)
|
||||
|
||||
module_name = "controllers.console.workspace.load_balancing_config"
|
||||
sys.modules.pop(module_name, None)
|
||||
module = importlib.import_module(module_name)
|
||||
return module
|
||||
|
||||
|
||||
def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
|
||||
return SimpleNamespace(current_role=role)
|
||||
|
||||
|
||||
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
|
||||
user = _mock_user(role)
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
|
||||
mock_service = MagicMock()
|
||||
monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
|
||||
return mock_service
|
||||
|
||||
|
||||
def _request_payload():
|
||||
return {"model": "gpt-4o", "model_type": ModelType.LLM, "credentials": {"api_key": "sk-***"}}
|
||||
|
||||
|
||||
def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||
|
||||
assert response == {"result": "success"}
|
||||
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
credentials={"api_key": "sk-***"},
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||
|
||||
assert response == {"result": "error", "error": "invalid credentials"}
|
||||
|
||||
|
||||
def test_validate_credentials_requires_privileged_role(
|
||||
app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
api = load_balancing_module.LoadBalancingCredentialsValidateApi()
|
||||
with pytest.raises(Forbidden):
|
||||
api.post(provider="openai")
|
||||
|
||||
|
||||
def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
|
||||
provider="openai", config_id="cfg-1"
|
||||
)
|
||||
|
||||
assert response == {"result": "success"}
|
||||
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
credentials={"api_key": "sk-***"},
|
||||
config_id="cfg-1",
|
||||
)
|
||||
@ -1,103 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
|
||||
# They are intentionally no-ops because the test already patches the required
|
||||
# behaviors explicitly via @patch and context managers below.
|
||||
@pytest.fixture
|
||||
def _mock_cache():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_user_tenant():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
api = Api(app)
|
||||
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
|
||||
db.init_app(app)
|
||||
# Configure session factory used by controller code
|
||||
with app.app_context():
|
||||
configure_session_factory(db.engine)
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
|
||||
)
|
||||
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
|
||||
@patch("controllers.console.workspace.tool_providers.Session")
|
||||
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
|
||||
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
|
||||
def test_create_mcp_provider_populates_tools(
|
||||
mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
|
||||
):
|
||||
# Arrange: reconnect returns tools immediately
|
||||
mock_reconnect.return_value = ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps(
|
||||
[{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
|
||||
),
|
||||
encrypted_credentials="{}",
|
||||
)
|
||||
|
||||
# Fake service.create_provider -> returns object with id for reload
|
||||
svc = MagicMock()
|
||||
create_result = MagicMock()
|
||||
create_result.id = "provider-1"
|
||||
svc.create_provider.return_value = create_result
|
||||
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
# Patch MCPToolManageService constructed inside controller
|
||||
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
|
||||
payload = {
|
||||
"server_url": "http://example.com/mcp",
|
||||
"name": "demo",
|
||||
"icon": "😀",
|
||||
"icon_type": "emoji",
|
||||
"icon_background": "#000",
|
||||
"server_identifier": "demo-sid",
|
||||
"configuration": {"timeout": 5, "sse_read_timeout": 30},
|
||||
"headers": {},
|
||||
"authentication": {},
|
||||
}
|
||||
# Act
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
|
||||
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
|
||||
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
|
||||
patch(
|
||||
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
|
||||
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
|
||||
),
|
||||
):
|
||||
resp = client.post(
|
||||
"/console/api/workspaces/current/tool-provider/mcp",
|
||||
data=json.dumps(payload),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert resp.status_code == 200
|
||||
body = resp.get_json()
|
||||
assert body.get("id") == "provider-1"
|
||||
# 若 transform 后包含 tools 字段,确保非空
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
@ -1,245 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import core.app.apps.workflow.app_runner as workflow_app_runner
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner, WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.enums import NodeType, SystemVariableKey, WorkflowType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def test_prepare_single_iteration_injects_system_variables_and_fake_workflow():
|
||||
node_id = "iteration_node"
|
||||
execution_id = "workflow-exec-123"
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow-id",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
environment_variables=[],
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": node_id,
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": NodeType.ITERATION,
|
||||
"title": "Iteration",
|
||||
"version": "1",
|
||||
"iterator_selector": ["start", "items"],
|
||||
"output_selector": [node_id, "output"],
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
)
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=MagicMock(), app_id="app-id")
|
||||
|
||||
system_inputs = SystemVariable(app_id="app-id", workflow_id="workflow-id", workflow_execution_id=execution_id)
|
||||
|
||||
graph, _, runtime_state = runner._prepare_single_node_execution(
|
||||
workflow=workflow,
|
||||
single_iteration_run=SimpleNamespace(node_id=node_id, inputs={"input_selector": [1, 2, 3]}),
|
||||
system_variables=system_inputs,
|
||||
)
|
||||
|
||||
assert runtime_state.variable_pool.system_variables.workflow_execution_id == execution_id
|
||||
assert runtime_state.variable_pool.get_by_prefix("sys")[SystemVariableKey.WORKFLOW_EXECUTION_ID] == execution_id
|
||||
assert graph.root_node.id == f"{node_id}_single_step_start"
|
||||
assert f"{node_id}_single_step_end" in graph.nodes
|
||||
|
||||
|
||||
def test_prepare_single_loop_injects_system_variables_and_fake_workflow():
|
||||
node_id = "loop_node"
|
||||
execution_id = "workflow-exec-456"
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow-id",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
environment_variables=[],
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": node_id,
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": NodeType.LOOP,
|
||||
"title": "Loop",
|
||||
"version": "1",
|
||||
"loop_count": 1,
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
"loop_variables": [],
|
||||
"outputs": {},
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
)
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=MagicMock(), app_id="app-id")
|
||||
|
||||
system_inputs = SystemVariable(app_id="app-id", workflow_id="workflow-id", workflow_execution_id=execution_id)
|
||||
|
||||
graph, _, runtime_state = runner._prepare_single_node_execution(
|
||||
workflow=workflow,
|
||||
single_loop_run=SimpleNamespace(node_id=node_id, inputs={}),
|
||||
system_variables=system_inputs,
|
||||
)
|
||||
|
||||
assert runtime_state.variable_pool.system_variables.workflow_execution_id == execution_id
|
||||
assert graph.root_node.id == f"{node_id}_single_step_start"
|
||||
assert f"{node_id}_single_step_end" in graph.nodes
|
||||
|
||||
|
||||
class DummyCommandChannel:
|
||||
def fetch_commands(self):
|
||||
return []
|
||||
|
||||
def send_command(self, command):
|
||||
return None
|
||||
|
||||
|
||||
def _empty_graph_engine_run(self):
|
||||
if False: # pragma: no cover
|
||||
yield None
|
||||
|
||||
|
||||
def _build_generate_entity(*, single_iteration_run=None, single_loop_run=None):
|
||||
if isinstance(single_iteration_run, dict):
|
||||
single_iteration_run = SimpleNamespace(**single_iteration_run)
|
||||
if isinstance(single_loop_run, dict):
|
||||
single_loop_run = SimpleNamespace(**single_loop_run)
|
||||
|
||||
base = SimpleNamespace(
|
||||
app_config=SimpleNamespace(app_id="app-id", workflow_id="workflow-id"),
|
||||
workflow_execution_id="workflow-exec-id",
|
||||
files=[],
|
||||
user_id="user-id",
|
||||
inputs={},
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
task_id="task-id",
|
||||
trace_manager=None,
|
||||
single_iteration_run=single_iteration_run,
|
||||
single_loop_run=single_loop_run,
|
||||
)
|
||||
|
||||
def is_single_stepping_container_nodes():
|
||||
return base.single_iteration_run is not None or base.single_loop_run is not None
|
||||
|
||||
base.is_single_stepping_container_nodes = is_single_stepping_container_nodes # type: ignore[attr-defined]
|
||||
return base
|
||||
|
||||
|
||||
def test_workflow_runner_attaches_persistence_for_full_run(monkeypatch):
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
monkeypatch.setattr(GraphEngine, "run", _empty_graph_engine_run)
|
||||
persistence_ctor = MagicMock(name="persistence_layer_ctor")
|
||||
monkeypatch.setattr(workflow_app_runner, "WorkflowPersistenceLayer", persistence_ctor)
|
||||
monkeypatch.setattr(workflow_app_runner, "RedisChannel", lambda *args, **kwargs: DummyCommandChannel())
|
||||
|
||||
queue_manager = MagicMock()
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow-id",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="1",
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "custom",
|
||||
"data": {"type": NodeType.START, "title": "Start", "version": "1", "variables": []},
|
||||
},
|
||||
{
|
||||
"id": "end",
|
||||
"type": "custom",
|
||||
"data": {"type": NodeType.END, "title": "End", "version": "1", "outputs": []},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{"source": "start", "target": "end", "sourceHandle": "source", "targetHandle": "target"},
|
||||
],
|
||||
},
|
||||
environment_variables=[],
|
||||
)
|
||||
generate_entity = _build_generate_entity()
|
||||
generate_entity.inputs = {"input": "value"}
|
||||
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="system-user-id",
|
||||
root_node_id=None,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
graph_engine_layers=(),
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
assert persistence_ctor.call_count == 1
|
||||
|
||||
|
||||
def test_workflow_runner_skips_persistence_for_single_step(monkeypatch):
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
monkeypatch.setattr(GraphEngine, "run", _empty_graph_engine_run)
|
||||
persistence_ctor = MagicMock(name="persistence_layer_ctor")
|
||||
monkeypatch.setattr(workflow_app_runner, "WorkflowPersistenceLayer", persistence_ctor)
|
||||
monkeypatch.setattr(workflow_app_runner, "RedisChannel", lambda *args, **kwargs: DummyCommandChannel())
|
||||
|
||||
queue_manager = MagicMock()
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow-id",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="1",
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "loop",
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": NodeType.LOOP,
|
||||
"title": "Loop",
|
||||
"version": "1",
|
||||
"loop_count": 1,
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
"loop_variables": [],
|
||||
"outputs": {},
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
environment_variables=[],
|
||||
)
|
||||
generate_entity = _build_generate_entity(single_loop_run={"node_id": "loop", "inputs": {}})
|
||||
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="system-user-id",
|
||||
root_node_id=None,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
graph_engine_layers=(),
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
assert persistence_ctor.call_count == 0
|
||||
@ -96,6 +96,9 @@ class TestToolProviderListCache:
|
||||
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*")
|
||||
mock_redis_client.delete.assert_called_once_with(*mock_keys)
|
||||
|
||||
def test_invalidate_cache_no_keys(self, mock_redis_client):
|
||||
"""Test invalidate cache - no cache keys for tenant"""
|
||||
tenant_id = "tenant_123"
|
||||
|
||||
@ -1,327 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import (
|
||||
PGVector,
|
||||
PGVectorConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestPGVector(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=False,
|
||||
)
|
||||
self.collection_name = "test_collection"
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_init(self, mock_pool_class):
|
||||
"""Test PGVector initialization."""
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
|
||||
assert pgvector._collection_name == self.collection_name
|
||||
assert pgvector.table_name == f"embedding_{self.collection_name}"
|
||||
assert pgvector.get_type() == "pgvector"
|
||||
assert pgvector.pool is not None
|
||||
assert pgvector.pg_bigm is False
|
||||
assert pgvector.index_hash is not None
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
def test_init_with_pg_bigm(self, mock_pool_class):
|
||||
"""Test PGVector initialization with pg_bigm enabled."""
|
||||
config = PGVectorConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
database="test_db",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
pg_bigm=True,
|
||||
)
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
pgvector = PGVector(self.collection_name, config)
|
||||
|
||||
assert pgvector.pg_bigm is True
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_basic(self, mock_redis, mock_pool_class):
|
||||
"""Test basic collection creation."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.__enter__ = MagicMock()
|
||||
mock_lock.__exit__ = MagicMock()
|
||||
mock_redis.lock.return_value = mock_lock
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.set.return_value = None
|
||||
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
# Mock connection and cursor
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.getconn.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.return_value = [1] # vector extension exists
|
||||
|
||||
pgvector = PGVector(self.collection_name, self.config)
|
||||
pgvector._create_collection(1536)
|
||||
|
||||
# Verify SQL execution calls
|
||||
assert mock_cursor.execute.called
|
||||
|
||||
# Check that CREATE TABLE was called with correct dimension
|
||||
create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
|
||||
assert len(create_table_calls) == 1
|
||||
assert "vector(1536)" in create_table_calls[0][0][0]
|
||||
|
||||
# Check that CREATE INDEX was called (dimension <= 2000)
|
||||
create_index_calls = [
|
||||
call for call in mock_cursor.execute.call_args_list if "CREATE INDEX" in str(call) and "hnsw" in str(call)
|
||||
]
|
||||
assert len(create_index_calls) == 1
|
||||
|
||||
# Verify Redis cache was set
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
|
||||
@patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
|
||||
def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation with dimension > 2000 (no HNSW index)."""
|
||||
# Mock Redis operations
|
||||
mock_lock = MagicMock()
|
||||
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None

        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = [1]  # vector extension exists

        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(3072)  # Dimension > 2000

        # Check that CREATE TABLE was called
        create_table_calls = [call for call in mock_cursor.execute.call_args_list if "CREATE TABLE" in str(call)]
        assert len(create_table_calls) == 1
        assert "vector(3072)" in create_table_calls[0][0][0]

        # Check that HNSW index was NOT created (dimension > 2000)
        hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)]
        assert len(hnsw_index_calls) == 0

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class):
        """Test collection creation with pg_bigm enabled."""
        config = PGVectorConfig(
            host="localhost",
            port=5432,
            user="test_user",
            password="test_password",
            database="test_db",
            min_connection=1,
            max_connection=5,
            pg_bigm=True,
        )

        # Mock Redis operations
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None

        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = [1]  # vector extension exists

        pgvector = PGVector(self.collection_name, config)
        pgvector._create_collection(1536)

        # Check that pg_bigm index was created
        bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)]
        assert len(bigm_index_calls) == 1

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class):
        """Test that vector extension is created if it doesn't exist."""
        # Mock Redis operations
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None

        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # First call: vector extension doesn't exist
        mock_cursor.fetchone.return_value = None

        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(1536)

        # Check that CREATE EXTENSION was called
        create_extension_calls = [
            call for call in mock_cursor.execute.call_args_list if "CREATE EXTENSION vector" in str(call)
        ]
        assert len(create_extension_calls) == 1

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class):
        """Test that collection creation is skipped when cache exists."""
        # Mock Redis operations - cache exists
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = 1  # Cache exists

        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor

        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(1536)

        # Check that no SQL was executed (early return due to cache)
        assert mock_cursor.execute.call_count == 0

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client")
    def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class):
        """Test that Redis lock is used during collection creation."""
        # Mock Redis operations
        mock_lock = MagicMock()
        mock_lock.__enter__ = MagicMock()
        mock_lock.__exit__ = MagicMock()
        mock_redis.lock.return_value = mock_lock
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None

        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = [1]  # vector extension exists

        pgvector = PGVector(self.collection_name, self.config)
        pgvector._create_collection(1536)

        # Verify Redis lock was acquired with correct lock name
        mock_redis.lock.assert_called_once_with("vector_indexing_test_collection_lock", timeout=20)

        # Verify lock context manager was entered and exited
        mock_lock.__enter__.assert_called_once()
        mock_lock.__exit__.assert_called_once()

    @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool")
    def test_get_cursor_context_manager(self, mock_pool_class):
        """Test that _get_cursor properly manages connection lifecycle."""
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.getconn.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor

        pgvector = PGVector(self.collection_name, self.config)

        with pgvector._get_cursor() as cur:
            assert cur == mock_cursor

        # Verify connection lifecycle methods were called
        mock_pool.getconn.assert_called_once()
        mock_cursor.close.assert_called_once()
        mock_conn.commit.assert_called_once()
        mock_pool.putconn.assert_called_once_with(mock_conn)


@pytest.mark.parametrize(
    "invalid_config_override",
    [
        {"host": ""},  # Test empty host
        {"port": 0},  # Test invalid port
        {"user": ""},  # Test empty user
        {"password": ""},  # Test empty password
        {"database": ""},  # Test empty database
        {"min_connection": 0},  # Test invalid min_connection
        {"max_connection": 0},  # Test invalid max_connection
        {"min_connection": 10, "max_connection": 5},  # Test min > max
    ],
)
def test_config_validation_parametrized(invalid_config_override):
    """Test configuration validation for various invalid inputs using parametrize."""
    config = {
        "host": "localhost",
        "port": 5432,
        "user": "test_user",
        "password": "test_password",
        "database": "test_db",
        "min_connection": 1,
        "max_connection": 5,
    }
    config.update(invalid_config_override)

    with pytest.raises(ValueError):
        PGVectorConfig(**config)


if __name__ == "__main__":
    unittest.main()
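For orientation, the sketch below only illustrates the branching that these tests pin down; it is not the real `PGVector._create_collection`, and the table layout and SQL text are assumptions. Only the dimension threshold and the `pg_bigm` switch mirror the assertions above.

```python
# Hypothetical sketch, not Dify's implementation: which DDL statements the tests
# above expect to appear, depending on embedding dimension and pg_bigm.
def build_collection_statements(table: str, dimension: int, pg_bigm: bool = False) -> list[str]:
    statements = [
        f"CREATE TABLE IF NOT EXISTS {table} (id uuid PRIMARY KEY, text text, meta jsonb, "
        f"embedding vector({dimension}))"
    ]
    if dimension <= 2000:
        # an HNSW index is only expected for dimensions up to 2000
        statements.append(f"CREATE INDEX ON {table} USING hnsw (embedding vector_cosine_ops)")
    if pg_bigm:
        # pg_bigm full-text index, matching the gin_bigm_ops assertion
        statements.append(f"CREATE INDEX ON {table} USING gin (text gin_bigm_ops)")
    return statements


assert not any("hnsw" in s for s in build_collection_statements("t", 3072))
assert any("gin_bigm_ops" in s for s in build_collection_statements("t", 1536, pg_bigm=True))
```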
@ -1,873 +0,0 @@
"""
Unit tests for DatasetRetrieval.process_metadata_filter_func.

This module provides comprehensive test coverage for the process_metadata_filter_func
method in the DatasetRetrieval class, which is responsible for building SQLAlchemy
filter expressions based on metadata filtering conditions.

Conditions Tested:
==================
1. **String Conditions**: contains, not contains, start with, end with
2. **Equality Conditions**: is / =, is not / ≠
3. **Null Conditions**: empty, not empty
4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >=
5. **List Conditions**: in
6. **Edge Cases**: None values, different data types (str, int, float)

Test Architecture:
==================
- Direct instantiation of DatasetRetrieval
- Mocking of DatasetDocument model attributes
- Verification of SQLAlchemy filter expressions
- Follows Arrange-Act-Assert (AAA) pattern

Running Tests:
==============
# Run all tests in this module
uv run --project api pytest \
    api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v

# Run a specific test
uv run --project api pytest \
    api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\
TestProcessMetadataFilterFunc::test_contains_condition_string_value -v
"""

from unittest.mock import MagicMock

import pytest

from core.rag.retrieval.dataset_retrieval import DatasetRetrieval

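A short usage sketch (not part of the original file) may help before the test class: the call signature is taken from the docstring above, and the condition strings and field names are borrowed from the tests that follow.

```python
# Usage sketch only: the method appends one SQLAlchemy expression per call and
# returns the same `filters` list it was given.
retrieval_example = DatasetRetrieval()
example_filters: list = []
example_filters = retrieval_example.process_metadata_filter_func(0, "contains", "author", "John", example_filters)
example_filters = retrieval_example.process_metadata_filter_func(1, ">", "year", 2020, example_filters)
assert len(example_filters) == 2  # one expression per condition
```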
class TestProcessMetadataFilterFunc:
|
||||
"""
|
||||
Comprehensive test suite for process_metadata_filter_func method.
|
||||
|
||||
This test class validates all metadata filtering conditions supported by
|
||||
the DatasetRetrieval class, including string operations, numeric comparisons,
|
||||
null checks, and list operations.
|
||||
|
||||
Method Signature:
|
||||
==================
|
||||
def process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
) -> list:
|
||||
|
||||
The method builds SQLAlchemy filter expressions by:
|
||||
1. Validating value is not None (except for empty/not empty conditions)
|
||||
2. Using DatasetDocument.doc_metadata JSON field operations
|
||||
3. Adding appropriate SQLAlchemy expressions to the filters list
|
||||
4. Returning the updated filters list
|
||||
|
||||
Mocking Strategy:
|
||||
==================
|
||||
- Mock DatasetDocument.doc_metadata to avoid database dependencies
|
||||
- Verify filter expressions are created correctly
|
||||
- Test with various data types (str, int, float, list)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def retrieval(self):
|
||||
"""
|
||||
Create a DatasetRetrieval instance for testing.
|
||||
|
||||
Returns:
|
||||
DatasetRetrieval: Instance to test process_metadata_filter_func
|
||||
"""
|
||||
return DatasetRetrieval()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_doc_metadata(self):
|
||||
"""
|
||||
Mock the DatasetDocument.doc_metadata JSON field.
|
||||
|
||||
The method uses DatasetDocument.doc_metadata[metadata_name] to access
|
||||
JSON fields. We mock this to avoid database dependencies.
|
||||
|
||||
Returns:
|
||||
Mock: Mocked doc_metadata attribute
|
||||
"""
|
||||
mock_metadata_field = MagicMock()
|
||||
|
||||
# Create mock for string access
|
||||
mock_string_access = MagicMock()
|
||||
mock_string_access.like = MagicMock()
|
||||
mock_string_access.notlike = MagicMock()
|
||||
mock_string_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_string_access.in_ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for float access (for numeric comparisons)
|
||||
mock_float_access = MagicMock()
|
||||
mock_float_access.__eq__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ne__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__lt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__gt__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__le__ = MagicMock(return_value=MagicMock())
|
||||
mock_float_access.__ge__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Create mock for null checks
|
||||
mock_null_access = MagicMock()
|
||||
mock_null_access.is_ = MagicMock(return_value=MagicMock())
|
||||
mock_null_access.isnot = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Setup __getitem__ to return appropriate mock based on usage
|
||||
def getitem_side_effect(name):
|
||||
if name in ["author", "title", "category"]:
|
||||
return mock_string_access
|
||||
elif name in ["year", "price", "rating"]:
|
||||
return mock_float_access
|
||||
else:
|
||||
return mock_string_access
|
||||
|
||||
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
|
||||
mock_metadata_field.as_string.return_value = mock_string_access
|
||||
mock_metadata_field.as_float.return_value = mock_float_access
|
||||
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
|
||||
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
|
||||
|
||||
return mock_metadata_field
|
||||
|
||||
# ==================== String Condition Tests ====================
|
||||
|
||||
def test_contains_condition_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'contains' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "John"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_contains_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with NOT LIKE expression
|
||||
- Pattern matching uses %value% syntax with negation
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not contains"
|
||||
metadata_name = "title"
|
||||
value = "banned"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_start_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'start with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses value% syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "start with"
|
||||
metadata_name = "category"
|
||||
value = "tech"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_end_with_condition(self, retrieval):
|
||||
"""
|
||||
Test 'end with' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with LIKE expression
|
||||
- Pattern matching uses %value syntax
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "end with"
|
||||
metadata_name = "filename"
|
||||
value = ".pdf"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Equality Condition Tests ====================
|
||||
|
||||
def test_is_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' (=) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with equality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = "Jane Doe"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_equals_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test '=' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is' condition
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "="
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_int_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with integer value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_condition_with_float_value(self, retrieval):
|
||||
"""
|
||||
Test 'is' condition with float value.
|
||||
|
||||
Verifies:
|
||||
- Numeric comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "price"
|
||||
value = 19.99
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_string_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' (≠) condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with inequality expression
|
||||
- String comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "author"
|
||||
value = "Unknown"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test '≠' condition with string value.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'is not' condition
|
||||
- Inequality expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≠"
|
||||
metadata_name = "category"
|
||||
value = "archived"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_is_not_condition_with_numeric_value(self, retrieval):
|
||||
"""
|
||||
Test 'is not' condition with numeric value.
|
||||
|
||||
Verifies:
|
||||
- Numeric inequality comparison is used
|
||||
- as_float() is called on the metadata field
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is not"
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Null Condition Tests ====================
|
||||
|
||||
def test_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'empty' condition (null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "empty"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_not_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test 'not empty' condition (not null check).
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with IS NOT NULL expression
|
||||
- Value can be None for this condition
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "not empty"
|
||||
metadata_name = "description"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Numeric Comparison Tests ====================
|
||||
|
||||
def test_before_condition(self, retrieval):
|
||||
"""
|
||||
Test 'before' (<) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "before"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '<' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'before' condition
|
||||
- Less than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "price"
|
||||
value = 100.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_after_condition(self, retrieval):
|
||||
"""
|
||||
Test 'after' (>) condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "after"
|
||||
metadata_name = "year"
|
||||
value = 2020
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_condition(self, retrieval):
|
||||
"""
|
||||
Test '>' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as 'after' condition
|
||||
- Greater than expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≤' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with less than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≤"
|
||||
metadata_name = "price"
|
||||
value = 50.0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_less_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '<=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≤' condition
|
||||
- Less than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<="
|
||||
metadata_name = "year"
|
||||
value = 2023
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_unicode(self, retrieval):
|
||||
"""
|
||||
Test '≥' condition.
|
||||
|
||||
Verifies:
|
||||
- Filters list is populated with greater than or equal expression
|
||||
- Numeric comparison is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "≥"
|
||||
metadata_name = "rating"
|
||||
value = 3.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_greater_than_or_equal_condition_ascii(self, retrieval):
|
||||
"""
|
||||
Test '>=' condition.
|
||||
|
||||
Verifies:
|
||||
- Same behavior as '≥' condition
|
||||
- Greater than or equal expression is used
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "year"
|
||||
value = 2000
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== List/In Condition Tests ====================
|
||||
|
||||
def test_in_condition_with_comma_separated_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with comma-separated string value.
|
||||
|
||||
Verifies:
|
||||
- String is split into list
|
||||
- Whitespace is trimmed from each value
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "tech, science, AI "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_list_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with list value.
|
||||
|
||||
Verifies:
|
||||
- List is processed correctly
|
||||
- None values are filtered out
|
||||
- IN expression is created with valid values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "tags"
|
||||
value = ["python", "javascript", None, "golang"]
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_tuple_value(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with tuple value.
|
||||
|
||||
Verifies:
|
||||
- Tuple is processed like a list
|
||||
- IN expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ("tech", "science", "ai")
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_empty_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with empty string value.
|
||||
|
||||
Verifies:
|
||||
- Empty string results in literal(False) filter
|
||||
- No valid values to match
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
# Verify it's a literal(False) expression
|
||||
# This is a bit tricky to test without access to the actual expression
|
||||
|
||||
def test_in_condition_with_only_whitespace(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with whitespace-only string value.
|
||||
|
||||
Verifies:
|
||||
- Whitespace-only string results in literal(False) filter
|
||||
- All values are stripped and filtered out
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = " , , "
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_in_condition_with_single_string(self, retrieval):
|
||||
"""
|
||||
Test 'in' condition with single non-comma string.
|
||||
|
||||
Verifies:
|
||||
- Single string is treated as single-item list
|
||||
- IN expression is created with one value
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "in"
|
||||
metadata_name = "category"
|
||||
value = "technology"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
# ==================== Edge Case Tests ====================
|
||||
|
||||
def test_none_value_with_non_empty_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with conditions that require value.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values (except empty/not empty)
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0 # No filter added
|
||||
|
||||
def test_none_value_with_equals_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with 'is' (=) condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "is"
|
||||
metadata_name = "author"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_none_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test None value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for None values
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "year"
|
||||
value = None
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_existing_filters_preserved(self, retrieval):
|
||||
"""
|
||||
Test that existing filters are preserved.
|
||||
|
||||
Verifies:
|
||||
- Existing filters in the list are not removed
|
||||
- New filters are appended to the list
|
||||
"""
|
||||
existing_filter = MagicMock()
|
||||
filters = [existing_filter]
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 2
|
||||
assert filters[0] == existing_filter
|
||||
|
||||
def test_multiple_filters_accumulated(self, retrieval):
|
||||
"""
|
||||
Test multiple calls to accumulate filters.
|
||||
|
||||
Verifies:
|
||||
- Each call adds a new filter to the list
|
||||
- All filters are preserved across calls
|
||||
"""
|
||||
filters = []
|
||||
|
||||
# First filter
|
||||
retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters)
|
||||
assert len(filters) == 1
|
||||
|
||||
# Second filter
|
||||
retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters)
|
||||
assert len(filters) == 2
|
||||
|
||||
# Third filter
|
||||
retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters)
|
||||
assert len(filters) == 3
|
||||
|
||||
def test_unknown_condition(self, retrieval):
|
||||
"""
|
||||
Test unknown/unsupported condition.
|
||||
|
||||
Verifies:
|
||||
- Original filters list is returned unchanged
|
||||
- No filter is added for unknown conditions
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "unknown_condition"
|
||||
metadata_name = "author"
|
||||
value = "test"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 0
|
||||
|
||||
def test_empty_string_value_with_contains(self, retrieval):
|
||||
"""
|
||||
Test empty string value with 'contains' condition.
|
||||
|
||||
Verifies:
|
||||
- Filter is added even with empty string
|
||||
- LIKE expression is created
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "author"
|
||||
value = ""
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_special_characters_in_value(self, retrieval):
|
||||
"""
|
||||
Test special characters in value string.
|
||||
|
||||
Verifies:
|
||||
- Special characters are handled in value
|
||||
- LIKE expression is created correctly
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "contains"
|
||||
metadata_name = "title"
|
||||
value = "C++ & Python's features"
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_zero_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test zero value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Zero is treated as valid value
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">"
|
||||
metadata_name = "price"
|
||||
value = 0
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_negative_value_with_numeric_condition(self, retrieval):
|
||||
"""
|
||||
Test negative value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Negative numbers are handled correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = "<"
|
||||
metadata_name = "temperature"
|
||||
value = -10.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
|
||||
def test_float_value_with_integer_comparison(self, retrieval):
|
||||
"""
|
||||
Test float value with numeric comparison condition.
|
||||
|
||||
Verifies:
|
||||
- Float values work correctly
|
||||
- Numeric comparison is performed
|
||||
"""
|
||||
filters = []
|
||||
sequence = 0
|
||||
condition = ">="
|
||||
metadata_name = "rating"
|
||||
value = 4.5
|
||||
|
||||
result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters)
|
||||
|
||||
assert result == filters
|
||||
assert len(filters) == 1
|
||||
@ -1,51 +0,0 @@
from core.workflow.entities import GraphInitParams
from core.workflow.graph_events import (
    NodeRunIterationStartedEvent,
    NodeRunIterationSucceededEvent,
    NodeRunSucceededEvent,
)
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable


def test_iteration_node_emits_iteration_events_when_iterator_empty():
    init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config={},
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=SystemVariable.empty(), user_inputs={}),
        start_at=0.0,
    )
    runtime_state.variable_pool.add(("start", "items"), [])

    node = IterationNode(
        id="iteration-node",
        config={
            "id": "iteration-node",
            "data": {
                "title": "Iteration",
                "iterator_selector": ["start", "items"],
                "output_selector": ["iteration-node", "output"],
            },
        },
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )

    events = list(node.run())

    assert any(isinstance(event, NodeRunIterationStartedEvent) for event in events)

    iteration_succeeded_event = next(event for event in events if isinstance(event, NodeRunIterationSucceededEvent))
    assert iteration_succeeded_event.steps == 0
    assert iteration_succeeded_event.outputs == {"output": []}

    assert any(isinstance(event, NodeRunSucceededEvent) for event in events)
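As a side note, the assertions above repeatedly scan the event list by type; a tiny helper along these lines (purely illustrative, not part of the test file) can keep such checks readable.

```python
from collections import defaultdict


def group_events_by_type(events):
    """Group emitted workflow events by their class name."""
    grouped = defaultdict(list)
    for event in events:
        grouped[type(event).__name__].append(event)
    return grouped


# e.g. grouped = group_events_by_type(list(node.run()))
#      assert grouped["NodeRunIterationSucceededEvent"][0].steps == 0
```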
@ -1,122 +0,0 @@
import base64
from unittest.mock import Mock, patch

import pytest

from core.mcp.types import (
    AudioContent,
    BlobResourceContents,
    CallToolResult,
    EmbeddedResource,
    ImageContent,
    TextResourceContents,
)
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
from core.tools.mcp_tool.tool import MCPTool


def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool:
    identity = ToolIdentity(
        author="test",
        name="test_mcp_tool",
        label=I18nObject(en_US="Test MCP Tool", zh_Hans="测试MCP工具"),
        provider="test_provider",
    )
    entity = ToolEntity(identity=identity, output_schema=output_schema or {})
    runtime = Mock(spec=ToolRuntime)
    runtime.credentials = {}
    return MCPTool(
        entity=entity,
        runtime=runtime,
        tenant_id="test_tenant",
        icon="",
        server_url="https://server.invalid",
        provider_id="provider_1",
        headers={},
    )


class TestMCPToolInvoke:
    @pytest.mark.parametrize(
        ("content_factory", "mime_type"),
        [
            (
                lambda b64, mt: ImageContent(type="image", data=b64, mimeType=mt),
                "image/png",
            ),
            (
                lambda b64, mt: AudioContent(type="audio", data=b64, mimeType=mt),
                "audio/mpeg",
            ),
        ],
    )
    def test_invoke_image_or_audio_yields_blob(self, content_factory, mime_type) -> None:
        tool = _make_mcp_tool()
        raw = b"\x00\x01test-bytes\x02"
        b64 = base64.b64encode(raw).decode()
        content = content_factory(b64, mime_type)
        result = CallToolResult(content=[content])

        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))

        assert len(messages) == 1
        msg = messages[0]
        assert msg.type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
        assert msg.message.blob == raw
        assert msg.meta == {"mime_type": mime_type}

    def test_invoke_embedded_text_resource_yields_text(self) -> None:
        tool = _make_mcp_tool()
        text_resource = TextResourceContents(uri="file://test.txt", mimeType="text/plain", text="hello world")
        content = EmbeddedResource(type="resource", resource=text_resource)
        result = CallToolResult(content=[content])

        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))

        assert len(messages) == 1
        msg = messages[0]
        assert msg.type == ToolInvokeMessage.MessageType.TEXT
        assert isinstance(msg.message, ToolInvokeMessage.TextMessage)
        assert msg.message.text == "hello world"

    @pytest.mark.parametrize(
        ("mime_type", "expected_mime"),
        [("application/pdf", "application/pdf"), (None, "application/octet-stream")],
    )
    def test_invoke_embedded_blob_resource_yields_blob(self, mime_type, expected_mime) -> None:
        tool = _make_mcp_tool()
        raw = b"binary-data"
        b64 = base64.b64encode(raw).decode()
        blob_resource = BlobResourceContents(uri="file://doc.bin", mimeType=mime_type, blob=b64)
        content = EmbeddedResource(type="resource", resource=blob_resource)
        result = CallToolResult(content=[content])

        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))

        assert len(messages) == 1
        msg = messages[0]
        assert msg.type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(msg.message, ToolInvokeMessage.BlobMessage)
        assert msg.message.blob == raw
        assert msg.meta == {"mime_type": expected_mime}

    def test_invoke_yields_variables_when_structured_content_and_schema(self) -> None:
        tool = _make_mcp_tool(output_schema={"type": "object"})
        result = CallToolResult(content=[], structuredContent={"a": 1, "b": "x"})

        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
            messages = list(tool._invoke(user_id="test_user", tool_parameters={}))

        # Expect two variable messages corresponding to keys a and b
        assert len(messages) == 2
        var_msgs = [m for m in messages if isinstance(m.message, ToolInvokeMessage.VariableMessage)]
        assert {m.message.variable_name for m in var_msgs} == {"a", "b"}
        # Validate values
        values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
        assert values == {"a": 1, "b": "x"}
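The blob-related tests above all hinge on the same two steps: base64-decoding the MCP payload and falling back to `application/octet-stream` when no mime type is given. A minimal sketch of that mapping (an illustration, not the `MCPTool` code):

```python
import base64


def decode_mcp_blob(b64_data: str, mime_type: str | None) -> tuple[bytes, str]:
    """Decode base64 blob content and default the mime type, as the tests expect."""
    return base64.b64decode(b64_data), mime_type or "application/octet-stream"


raw, mime = decode_mcp_blob(base64.b64encode(b"binary-data").decode(), None)
assert raw == b"binary-data"
assert mime == "application/octet-stream"
```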
api/uv.lock (generated, 12 changes)
@ -1953,14 +1953,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "fickling"
|
||||
version = "0.1.5"
|
||||
version = "0.1.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "stdlib-list" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/41/94/0d0ce455952c036cfee235637f786c1d1d07d1b90f6a4dfb50e0eff929d6/fickling-0.1.5.tar.gz", hash = "sha256:92f9b49e717fa8dbc198b4b7b685587adb652d85aa9ede8131b3e44494efca05", size = 282462, upload-time = "2025-11-18T05:04:30.748Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/07/ab/7571453f9365c17c047b5a7b7e82692a7f6be51203f295030886758fd57a/fickling-0.1.6.tar.gz", hash = "sha256:03cb5d7bd09f9169c7583d2079fad4b3b88b25f865ed0049172e5cb68582311d", size = 284033, upload-time = "2025-12-15T18:14:58.721Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/a7/d25912b2e3a5b0a37e6f460050bbc396042b5906a6563a1962c484abc3c6/fickling-0.1.5-py3-none-any.whl", hash = "sha256:6aed7270bfa276e188b0abe043a27b3a042129d28ec1fa6ff389bdcc5ad178bb", size = 46240, upload-time = "2025-11-18T05:04:29.048Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/76/99/cc04258dda421bc612cdfe4be8c253f45b922f1c7f268b5a0b9962d9cd12/fickling-0.1.6-py3-none-any.whl", hash = "sha256:465d0069548bfc731bdd75a583cb4cf5a4b2666739c0f76287807d724b147ed3", size = 47922, upload-time = "2025-12-15T18:14:57.526Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3072,11 +3072,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "json-repair"
|
||||
version = "0.54.3"
|
||||
version = "0.54.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b5/86/48b12ac02032f121ac7e5f11a32143edca6c1e3d19ffc54d6fb9ca0aafd0/json_repair-0.54.3.tar.gz", hash = "sha256:e50feec9725e52ac91f12184609754684ac1656119dfbd31de09bdaf9a1d8bf6", size = 38626, upload-time = "2025-12-15T09:41:58.594Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/00/46/d3a4d9a3dad39bb4a2ad16b8adb9fe2e8611b20b71197fe33daa6768e85d/json_repair-0.54.1.tar.gz", hash = "sha256:d010bc31f1fc66e7c36dc33bff5f8902674498ae5cb8e801ad455a53b455ad1d", size = 38555, upload-time = "2025-11-19T14:55:24.265Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/08/abe317237add63c3e62f18a981bccf92112b431835b43d844aedaf61f4a0/json_repair-0.54.3-py3-none-any.whl", hash = "sha256:4cdc132ee27d4780576f71bf27a113877046224a808bfc17392e079cb344fb81", size = 29357, upload-time = "2025-12-15T09:41:57.436Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/db/96/c9aad7ee949cc1bf15df91f347fbc2d3bd10b30b80c7df689ce6fe9332b5/json_repair-0.54.1-py3-none-any.whl", hash = "sha256:016160c5db5d5fe443164927bb58d2dfbba5f43ad85719fa9bc51c713a443ab1", size = 29311, upload-time = "2025-11-19T14:55:22.886Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@ -399,7 +399,6 @@ CONSOLE_CORS_ALLOW_ORIGINS=*
COOKIE_DOMAIN=
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
NEXT_PUBLIC_COOKIE_DOMAIN=
NEXT_PUBLIC_BATCH_CONCURRENCY=5

# ------------------------------
# File Storage Configuration

@ -108,7 +108,6 @@ x-shared-env: &shared-api-worker-env
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
COOKIE_DOMAIN: ${COOKIE_DOMAIN:-}
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
NEXT_PUBLIC_BATCH_CONCURRENCY: ${NEXT_PUBLIC_BATCH_CONCURRENCY:-5}
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}

@ -54,17 +54,17 @@
"publish:npm": "./scripts/publish.sh"
},
"dependencies": {
"axios": "^1.13.2"
"axios": "^1.3.5"
},
"devDependencies": {
"@eslint/js": "^9.39.2",
"@types/node": "^25.0.3",
"@eslint/js": "^9.2.0",
"@types/node": "^20.11.30",
"@typescript-eslint/eslint-plugin": "^8.50.1",
"@typescript-eslint/parser": "^8.50.1",
"@vitest/coverage-v8": "4.0.16",
"eslint": "^9.39.2",
"@vitest/coverage-v8": "1.6.1",
"eslint": "^9.2.0",
"tsup": "^8.5.1",
"typescript": "^5.9.3",
"vitest": "^4.0.16"
"typescript": "^5.4.5",
"vitest": "^1.5.0"
}
}

sdks/nodejs-client/pnpm-lock.yaml (generated, 899 changes): file diff suppressed because it is too large.
@ -1,2 +0,0 @@
onlyBuiltDependencies:
- esbuild
@ -73,6 +73,3 @@ NEXT_PUBLIC_MAX_TREE_DEPTH=50

# The API key of amplitude
NEXT_PUBLIC_AMPLITUDE_API_KEY=

# number of concurrency
NEXT_PUBLIC_BATCH_CONCURRENCY=5

@ -1,5 +1,5 @@
# base image
FROM node:22-alpine3.21 AS base
FROM node:24-alpine AS base
LABEL maintainer="takatost@gmail.com"

# if you located in China, you can use aliyun mirror to speed up

@ -8,8 +8,8 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next

Before starting the web frontend service, please make sure the following environment is ready.

- [Node.js](https://nodejs.org) >= v22.11.x
- [pnpm](https://pnpm.io) v10.x
- [Node.js](https://nodejs.org)
- [pnpm](https://pnpm.io)

First, install the dependencies:

@ -1,6 +1,5 @@
import type { ReactNode } from 'react'
import * as React from 'react'
import { AppInitializer } from '@/app/components/app-initializer'
import AmplitudeProvider from '@/app/components/base/amplitude'
import GA, { GaType } from '@/app/components/base/ga'
import Zendesk from '@/app/components/base/zendesk'
@ -8,6 +7,7 @@ import GotoAnything from '@/app/components/goto-anything'
import Header from '@/app/components/header'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import ReadmePanel from '@/app/components/plugins/readme-panel'
import SwrInitializer from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ModalContextProvider } from '@/context/modal-context'
@ -20,7 +20,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
<>
<GA gaType={GaType.admin} />
<AmplitudeProvider />
<AppInitializer>
<SwrInitializer>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
@ -38,7 +38,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
</EventEmitterContextProvider>
</AppContextProvider>
<Zendesk />
</AppInitializer>
</SwrInitializer>
</>
)
}

@ -14,6 +14,7 @@ import { useWebAppStore } from '@/context/web-app-context'
import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common'
import { fetchAccessToken } from '@/service/share'
import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth'
import { encryptVerificationCode } from '@/utils/encryption'

export default function CheckCode() {
const { t } = useTranslation()
@ -64,7 +65,7 @@ export default function CheckCode() {
return
}
setIsLoading(true)
const ret = await webAppEmailLoginWithCode({ email, code, token })
const ret = await webAppEmailLoginWithCode({ email, code: encryptVerificationCode(code), token })
if (ret.result === 'success') {
setWebAppAccessToken(ret.data.access_token)
const { access_token } = await fetchAccessToken({

@ -14,6 +14,7 @@ import { useWebAppStore } from '@/context/web-app-context'
import { webAppLogin } from '@/service/common'
import { fetchAccessToken } from '@/service/share'
import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth'
import { encryptPassword } from '@/utils/encryption'

type MailAndPasswordAuthProps = {
isEmailSetup: boolean
@ -72,7 +73,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
setIsLoading(true)
const loginData: Record<string, any> = {
email,
password,
password: encryptPassword(password),
language: locale,
remember_me: true,
}

@ -1,9 +1,9 @@
import type { ReactNode } from 'react'
import * as React from 'react'
import { AppInitializer } from '@/app/components/app-initializer'
import AmplitudeProvider from '@/app/components/base/amplitude'
import GA, { GaType } from '@/app/components/base/ga'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import SwrInitor from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ModalContextProvider } from '@/context/modal-context'
@ -15,7 +15,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
<>
<GA gaType={GaType.admin} />
<AmplitudeProvider />
<AppInitializer>
<SwrInitor>
<AppContextProvider>
<EventEmitterContextProvider>
<ProviderContextProvider>
@ -30,7 +30,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
</ProviderContextProvider>
</EventEmitterContextProvider>
</AppContextProvider>
</AppInitializer>
</SwrInitor>
</>
)
}

@ -47,12 +47,6 @@ const getCheckboxDefaultSelectValue = (value: InputVar['default']) => {
const parseCheckboxSelectValue = (value: string) =>
value === CHECKBOX_DEFAULT_TRUE_VALUE

const normalizeSelectDefaultValue = (inputVar: InputVar) => {
if (inputVar.type === InputVarType.select && inputVar.default === '')
return { ...inputVar, default: undefined }
return inputVar
}

export type IConfigModalProps = {
isCreate?: boolean
payload?: InputVar
@ -73,7 +67,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
}) => {
const { modelConfig } = useContext(ConfigContext)
const { t } = useTranslation()
const [tempPayload, setTempPayload] = useState<InputVar>(() => normalizeSelectDefaultValue(payload || getNewVarInWorkflow('') as any))
const [tempPayload, setTempPayload] = useState<InputVar>(() => payload || getNewVarInWorkflow('') as any)
const { type, label, variable, options, max_length } = tempPayload
const modalRef = useRef<HTMLDivElement>(null)
const appDetail = useAppStore(state => state.appDetail)
@ -188,8 +182,6 @@ const ConfigModal: FC<IConfigModalProps> = ({

const newPayload = produce(tempPayload, (draft) => {
draft.type = type
if (type === InputVarType.select)
draft.default = undefined
if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) {
(Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => {
if (key !== 'max_length')

@ -176,7 +176,7 @@ const DatasetConfig: FC = () => {
}))
}, [setDatasetConfigs, datasetConfigsRef])

const handleAddCondition = useCallback<HandleAddCondition>(({ id, name, type }) => {
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
let operator: ComparisonOperator = ComparisonOperator.is

if (type === MetadataFilteringVariableType.number)
@ -184,7 +184,6 @@ const DatasetConfig: FC = () => {

const newCondition = {
id: uuid4(),
metadata_id: id, // Save metadata.id for reliable reference
name,
comparison_operator: operator,
}

@ -1,141 +0,0 @@
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { IndexingType } from '@/app/components/datasets/create/step-two'
|
||||
import { DatasetPermission } from '@/models/datasets'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import SelectDataSet from './index'
|
||||
|
||||
vi.mock('@/i18n-config/i18next-config', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
changeLanguage: vi.fn(),
|
||||
addResourceBundle: vi.fn(),
|
||||
use: vi.fn().mockReturnThis(),
|
||||
init: vi.fn(),
|
||||
addResource: vi.fn(),
|
||||
hasResourceBundle: vi.fn().mockReturnValue(true),
|
||||
},
|
||||
}))
|
||||
const mockUseInfiniteScroll = vi.fn()
|
||||
vi.mock('ahooks', async (importOriginal) => {
|
||||
const actual = await importOriginal()
|
||||
return {
|
||||
...(typeof actual === 'object' && actual !== null ? actual : {}),
|
||||
useInfiniteScroll: (...args: any[]) => mockUseInfiniteScroll(...args),
|
||||
}
|
||||
})
|
||||
|
||||
const mockUseInfiniteDatasets = vi.fn()
vi.mock('@/service/knowledge/use-dataset', () => ({
  useInfiniteDatasets: (...args: any[]) => mockUseInfiniteDatasets(...args),
}))

vi.mock('@/hooks/use-knowledge', () => ({
  useKnowledge: () => ({
    formatIndexingTechniqueAndMethod: (tech: string, method: string) => `${tech}:${method}`,
  }),
}))

const baseProps = {
  isShow: true,
  onClose: vi.fn(),
  selectedIds: [] as string[],
  onSelect: vi.fn(),
}

const makeDataset = (overrides: Partial<DataSet>): DataSet => ({
  id: 'dataset-id',
  name: 'Dataset Name',
  provider: 'internal',
  icon_info: {
    icon_type: 'emoji',
    icon: '💾',
    icon_background: '#fff',
    icon_url: '',
  },
  embedding_available: true,
  is_multimodal: false,
  description: '',
  permission: DatasetPermission.allTeamMembers,
  indexing_technique: IndexingType.ECONOMICAL,
  retrieval_model_dict: {
    search_method: RETRIEVE_METHOD.fullText,
    top_k: 5,
    reranking_enable: false,
    reranking_model: {
      reranking_model_name: '',
      reranking_provider_name: '',
    },
    score_threshold_enabled: false,
    score_threshold: 0,
  },
  ...overrides,
} as DataSet)

describe('SelectDataSet', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  it('renders dataset entries, allows selection, and fires onSelect', async () => {
    const datasetOne = makeDataset({
      id: 'set-1',
      name: 'Dataset One',
      is_multimodal: true,
      indexing_technique: IndexingType.ECONOMICAL,
    })
    const datasetTwo = makeDataset({
      id: 'set-2',
      name: 'Hidden Dataset',
      embedding_available: false,
      provider: 'external',
    })
    mockUseInfiniteDatasets.mockReturnValue({
      data: { pages: [{ data: [datasetOne, datasetTwo] }] },
      isLoading: false,
      isFetchingNextPage: false,
      fetchNextPage: vi.fn(),
      hasNextPage: false,
    })

    const onSelect = vi.fn()
    await act(async () => {
      render(<SelectDataSet {...baseProps} onSelect={onSelect} selectedIds={[]} />)
    })

    expect(screen.getByText('Dataset One')).toBeInTheDocument()
    expect(screen.getByText('Hidden Dataset')).toBeInTheDocument()

    await act(async () => {
      fireEvent.click(screen.getByText('Dataset One'))
    })
    expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument()

    const addButton = screen.getByRole('button', { name: 'common.operation.add' })
    await act(async () => {
      fireEvent.click(addButton)
    })
    expect(onSelect).toHaveBeenCalledWith([datasetOne])
  })

  it('shows empty state when no datasets are available and disables add', async () => {
    mockUseInfiniteDatasets.mockReturnValue({
      data: { pages: [{ data: [] }] },
      isLoading: false,
      isFetchingNextPage: false,
      fetchNextPage: vi.fn(),
      hasNextPage: false,
    })

    await act(async () => {
      render(<SelectDataSet {...baseProps} onSelect={vi.fn()} selectedIds={[]} />)
    })

    expect(screen.getByText('appDebug.feature.dataSet.noDataSet')).toBeInTheDocument()
    expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create')
    expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
  })
})
@@ -679,7 +679,7 @@ const Configuration: FC = () => {
        const toolInCollectionList = collectionList.find(c => tool.provider_id === c.id)
        return {
          ...tool,
          isDeleted: res.deleted_tools?.some((deletedTool: any) => deletedTool.provider_id === tool.provider_id && deletedTool.tool_name === tool.tool_name) ?? false,
          isDeleted: res.deleted_tools?.some((deletedTool: any) => deletedTool.id === tool.id && deletedTool.tool_name === tool.tool_name) ?? false,
          notAuthor: toolInCollectionList?.is_team_authorization === false,
          ...(tool.provider_type === 'builtin'
            ? {

@@ -1,125 +0,0 @@
|
||||
import type { IPromptValuePanelProps } from './index'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useStore } from '@/app/components/app/store'
|
||||
import ConfigContext from '@/context/debug-configuration'
|
||||
import { AppModeEnum, ModelModeType, Resolution } from '@/types/app'
|
||||
import PromptValuePanel from './index'
|
||||
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/features/new-feature-panel/feature-bar', () => ({
|
||||
__esModule: true,
|
||||
default: ({ onFeatureBarClick }: { onFeatureBarClick: () => void }) => (
|
||||
<button type="button" onClick={onFeatureBarClick}>
|
||||
feature bar
|
||||
</button>
|
||||
),
|
||||
}))
|
||||
|
||||
const mockSetShowAppConfigureFeaturesModal = vi.fn()
|
||||
const mockUseStore = vi.mocked(useStore)
|
||||
const mockSetInputs = vi.fn()
|
||||
const mockOnSend = vi.fn()
|
||||
|
||||
const promptVariables = [
|
||||
{ key: 'textVar', name: 'Text Var', type: 'string', required: true },
|
||||
{ key: 'boolVar', name: 'Boolean Var', type: 'checkbox' },
|
||||
] as const
|
||||
|
||||
const baseContextValue: any = {
|
||||
modelModeType: ModelModeType.completion,
|
||||
modelConfig: {
|
||||
configs: {
|
||||
prompt_template: 'prompt template',
|
||||
prompt_variables: promptVariables,
|
||||
},
|
||||
},
|
||||
setInputs: mockSetInputs,
|
||||
mode: AppModeEnum.COMPLETION,
|
||||
isAdvancedMode: false,
|
||||
completionPromptConfig: {
|
||||
prompt: { text: 'completion' },
|
||||
conversation_histories_role: { user_prefix: 'user', assistant_prefix: 'assistant' },
|
||||
},
|
||||
chatPromptConfig: { prompt: [] },
|
||||
} as any
|
||||
|
||||
const defaultProps: IPromptValuePanelProps = {
|
||||
appType: AppModeEnum.COMPLETION,
|
||||
onSend: mockOnSend,
|
||||
inputs: { textVar: 'initial', boolVar: false },
|
||||
visionConfig: { enabled: false, number_limits: 0, detail: Resolution.low, transfer_methods: [] },
|
||||
onVisionFilesChange: vi.fn(),
|
||||
}
|
||||
|
||||
const renderPanel = (options: {
|
||||
context?: Partial<typeof baseContextValue>
|
||||
props?: Partial<IPromptValuePanelProps>
|
||||
} = {}) => {
|
||||
const contextValue = { ...baseContextValue, ...options.context }
|
||||
const props = { ...defaultProps, ...options.props }
|
||||
return render(
|
||||
<ConfigContext.Provider value={contextValue}>
|
||||
<PromptValuePanel {...props} />
|
||||
</ConfigContext.Provider>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('PromptValuePanel', () => {
|
||||
beforeEach(() => {
|
||||
mockUseStore.mockImplementation(selector => selector({
|
||||
setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal,
|
||||
appSidebarExpand: '',
|
||||
currentLogModalActiveTab: 'prompt',
|
||||
showPromptLogModal: false,
|
||||
showAgentLogModal: false,
|
||||
setShowPromptLogModal: vi.fn(),
|
||||
setShowAgentLogModal: vi.fn(),
|
||||
showMessageLogModal: false,
|
||||
showAppConfigureFeaturesModal: false,
|
||||
} as any))
|
||||
mockSetInputs.mockClear()
|
||||
mockOnSend.mockClear()
|
||||
mockSetShowAppConfigureFeaturesModal.mockClear()
|
||||
})
|
||||
|
||||
it('updates inputs, clears values, and triggers run when ready', async () => {
|
||||
renderPanel()
|
||||
|
||||
const textInput = screen.getByPlaceholderText('Text Var')
|
||||
fireEvent.change(textInput, { target: { value: 'updated' } })
|
||||
expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ textVar: 'updated' }))
|
||||
|
||||
const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
|
||||
fireEvent.click(clearButton)
|
||||
|
||||
expect(mockSetInputs).toHaveBeenLastCalledWith({
|
||||
textVar: '',
|
||||
boolVar: '',
|
||||
})
|
||||
|
||||
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
|
||||
expect(runButton).not.toBeDisabled()
|
||||
fireEvent.click(runButton)
|
||||
await waitFor(() => expect(mockOnSend).toHaveBeenCalledTimes(1))
|
||||
})
|
||||
|
||||
it('disables run when mode is not completion', () => {
|
||||
renderPanel({
|
||||
context: {
|
||||
mode: AppModeEnum.CHAT,
|
||||
},
|
||||
props: {
|
||||
appType: AppModeEnum.CHAT,
|
||||
},
|
||||
})
|
||||
|
||||
const runButton = screen.getByRole('button', { name: 'appDebug.inputs.run' })
|
||||
expect(runButton).toBeDisabled()
|
||||
fireEvent.click(runButton)
|
||||
expect(mockOnSend).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -1,29 +0,0 @@
import type { PromptVariable } from '@/models/debug'

import { describe, expect, it } from 'vitest'
import { replaceStringWithValues } from './utils'

const promptVariables: PromptVariable[] = [
  { key: 'user', name: 'User', type: 'string' },
  { key: 'topic', name: 'Topic', type: 'string' },
]

describe('replaceStringWithValues', () => {
  it('should replace placeholders when inputs have values', () => {
    const template = 'Hello {{user}} talking about {{topic}}'
    const result = replaceStringWithValues(template, promptVariables, { user: 'Alice', topic: 'cats' })
    expect(result).toBe('Hello Alice talking about cats')
  })

  it('should use prompt variable name when value is missing', () => {
    const template = 'Hi {{user}} from {{topic}}'
    const result = replaceStringWithValues(template, promptVariables, {})
    expect(result).toBe('Hi {{User}} from {{Topic}}')
  })

  it('should leave placeholder untouched when no variable is defined', () => {
    const template = 'Unknown {{missing}} placeholder'
    const result = replaceStringWithValues(template, promptVariables, {})
    expect(result).toBe('Unknown {{missing}} placeholder')
  })
})
@@ -1,136 +0,0 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import Apps from './index'
|
||||
|
||||
const mockUseExploreAppList = vi.fn()
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useDebounceFn: (fn: () => void) => ({
|
||||
run: () => setTimeout(fn, 0),
|
||||
cancel: vi.fn(),
|
||||
flush: () => fn(),
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({ isCurrentWorkspaceEditor: true }),
|
||||
}))
|
||||
vi.mock('use-context-selector', async () => {
|
||||
const actual = await vi.importActual<typeof import('use-context-selector')>('use-context-selector')
|
||||
return {
|
||||
...actual,
|
||||
useContext: () => ({ hasEditPermission: true }),
|
||||
}
|
||||
})
|
||||
vi.mock('nuqs', () => ({
|
||||
useQueryState: () => ['Recommended', vi.fn()],
|
||||
}))
|
||||
vi.mock('@/service/use-explore', () => ({
|
||||
useExploreAppList: () => mockUseExploreAppList(),
|
||||
}))
|
||||
vi.mock('@/app/components/app/type-selector', () => ({
|
||||
__esModule: true,
|
||||
default: ({ value, onChange }: { value: AppModeEnum[], onChange: (value: AppModeEnum[]) => void }) => (
|
||||
<button data-testid="type-selector" onClick={() => onChange([...value, 'chat' as AppModeEnum])}>{value.join(',')}</button>
|
||||
),
|
||||
}))
|
||||
vi.mock('../app-card', () => ({
|
||||
__esModule: true,
|
||||
default: ({ app, onCreate }: { app: any, onCreate: () => void }) => (
|
||||
<div
|
||||
data-testid="app-card"
|
||||
data-name={app.app.name}
|
||||
onClick={onCreate}
|
||||
>
|
||||
{app.app.name}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
vi.mock('@/app/components/explore/create-app-modal', () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="create-from-template-modal" />,
|
||||
}))
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: vi.fn() },
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/service/apps', () => ({
|
||||
importDSL: vi.fn().mockResolvedValue({ app_id: '1' }),
|
||||
}))
|
||||
vi.mock('@/service/explore', () => ({
|
||||
fetchAppDetail: vi.fn().mockResolvedValue({
|
||||
export_data: 'dsl',
|
||||
mode: 'chat',
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({
|
||||
usePluginDependencies: () => ({
|
||||
handleCheckPluginDependencies: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/utils/app-redirection', () => ({
|
||||
getRedirection: vi.fn(),
|
||||
}))
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({ push: vi.fn() }),
|
||||
}))
|
||||
|
||||
const createAppEntry = (name: string, category: string) => ({
|
||||
app_id: name,
|
||||
category,
|
||||
app: {
|
||||
id: name,
|
||||
name,
|
||||
icon_type: 'emoji',
|
||||
icon: '🙂',
|
||||
icon_background: '#000',
|
||||
icon_url: null,
|
||||
description: 'desc',
|
||||
mode: AppModeEnum.CHAT,
|
||||
},
|
||||
})
|
||||
|
||||
describe('Apps', () => {
|
||||
const defaultData = {
|
||||
allList: [
|
||||
createAppEntry('Alpha', 'Cat A'),
|
||||
createAppEntry('Bravo', 'Cat B'),
|
||||
],
|
||||
categories: ['Cat A', 'Cat B'],
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseExploreAppList.mockReturnValue({
|
||||
data: defaultData,
|
||||
isLoading: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('renders template cards when data is available', () => {
|
||||
render(<Apps />)
|
||||
|
||||
expect(screen.getAllByTestId('app-card')).toHaveLength(2)
|
||||
expect(screen.getByText('Alpha')).toBeInTheDocument()
|
||||
expect(screen.getByText('Bravo')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('opens create modal when a template card is clicked', () => {
|
||||
render(<Apps />)
|
||||
|
||||
fireEvent.click(screen.getAllByTestId('app-card')[0])
|
||||
expect(screen.getByTestId('create-from-template-modal')).toBeInTheDocument()
|
||||
})
|
||||
it('shows no template message when list is empty', () => {
|
||||
mockUseExploreAppList.mockReturnValueOnce({
|
||||
data: { allList: [], categories: [] },
|
||||
isLoading: false,
|
||||
})
|
||||
|
||||
render(<Apps />)
|
||||
|
||||
expect(screen.getByText('app.newApp.noTemplateFound')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.newApp.noTemplateFoundTip')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -8,7 +8,6 @@ import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import AppTypeSelector from '@/app/components/app/type-selector'
import { trackEvent } from '@/app/components/base/amplitude'
import Divider from '@/app/components/base/divider'
@@ -19,7 +18,7 @@ import CreateAppModal from '@/app/components/explore/create-app-modal'
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
import ExploreContext from '@/context/explore-context'
import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import { DSLImportMode } from '@/models/app'
import { importDSL } from '@/service/apps'
import { fetchAppDetail } from '@/service/explore'
@@ -47,7 +46,6 @@ const Apps = ({
  const { t } = useTranslation()
  const { isCurrentWorkspaceEditor } = useAppContext()
  const { push } = useRouter()
  const { hasEditPermission } = useContext(ExploreContext)
  const allCategoriesEn = AppCategories.RECOMMENDED

  const [keywords, setKeywords] = useState('')
@@ -63,7 +61,10 @@ const Apps = ({
  }

  const [currentType, setCurrentType] = useState<AppModeEnum[]>([])
  const [currCategory, setCurrCategory] = useState<AppCategories | string>(allCategoriesEn)
  const [currCategory, setCurrCategory] = useTabSearchParams({
    defaultTab: allCategoriesEn,
    disableSearchParams: true,
  })

  const {
    data,
@@ -214,7 +215,7 @@ const Apps = ({
            <AppCard
              key={app.app_id}
              app={app}
              canCreate={hasEditPermission}
              canCreate={isCurrentWorkspaceEditor}
              onCreate={() => {
                setCurrApp(app)
                setIsShowCreateModal(true)

@@ -1,38 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Sidebar, { AppCategories } from './sidebar'

vi.mock('@remixicon/react', () => ({
  RiStickyNoteAddLine: () => <span>sticky</span>,
  RiThumbUpLine: () => <span>thumb</span>,
}))
describe('Sidebar', () => {
  it('renders recommended and custom categories', () => {
    render(<Sidebar current={AppCategories.RECOMMENDED} categories={['Cat A', 'Cat B']} />)

    expect(screen.getByText('app.newAppFromTemplate.sidebar.Recommended')).toBeInTheDocument()
    expect(screen.getByText('Cat A')).toBeInTheDocument()
    expect(screen.getByText('Cat B')).toBeInTheDocument()
  })

  it('notifies callbacks when items are clicked', () => {
    const onClick = vi.fn()
    const onCreate = vi.fn()
    render(
      <Sidebar
        current="Cat A"
        categories={['Cat A']}
        onClick={onClick}
        onCreateFromBlank={onCreate}
      />,
    )

    fireEvent.click(screen.getByText('app.newAppFromTemplate.sidebar.Recommended'))
    expect(onClick).toHaveBeenCalledWith(AppCategories.RECOMMENDED)

    fireEvent.click(screen.getByText('Cat A'))
    expect(onClick).toHaveBeenCalledWith('Cat A')

    fireEvent.click(screen.getByText('app.newApp.startFromBlank'))
    expect(onCreate).toHaveBeenCalled()
  })
})
@@ -1,162 +0,0 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { createApp } from '@/service/apps'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
import CreateAppModal from './index'
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useDebounceFn: (fn: (...args: any[]) => any) => {
|
||||
const run = (...args: any[]) => fn(...args)
|
||||
const cancel = vi.fn()
|
||||
const flush = vi.fn()
|
||||
return { run, cancel, flush }
|
||||
},
|
||||
useKeyPress: vi.fn(),
|
||||
useHover: () => false,
|
||||
}))
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/service/apps', () => ({
|
||||
createApp: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/utils/app-redirection', () => ({
|
||||
getRedirection: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useDocLink: () => () => '/guides',
|
||||
}))
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
__esModule: true,
|
||||
default: () => ({ theme: 'light' }),
|
||||
}))
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
const mockUseRouter = vi.mocked(useRouter)
|
||||
const mockPush = vi.fn()
|
||||
const mockCreateApp = vi.mocked(createApp)
|
||||
const mockTrackEvent = vi.mocked(trackEvent)
|
||||
const mockGetRedirection = vi.mocked(getRedirection)
|
||||
const mockUseProviderContext = vi.mocked(useProviderContext)
|
||||
const mockUseAppContext = vi.mocked(useAppContext)
|
||||
|
||||
const defaultPlanUsage = {
|
||||
buildApps: 0,
|
||||
teamMembers: 0,
|
||||
annotatedResponse: 0,
|
||||
documentsUploadQuota: 0,
|
||||
apiRateLimit: 0,
|
||||
triggerEvents: 0,
|
||||
vectorSpace: 0,
|
||||
}
|
||||
|
||||
const renderModal = () => {
|
||||
const onClose = vi.fn()
|
||||
const onSuccess = vi.fn()
|
||||
render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
|
||||
<CreateAppModal show onClose={onClose} onSuccess={onSuccess} defaultAppMode={AppModeEnum.ADVANCED_CHAT} />
|
||||
</ToastContext.Provider>,
|
||||
)
|
||||
return { onClose, onSuccess }
|
||||
}
|
||||
|
||||
describe('CreateAppModal', () => {
|
||||
const mockSetItem = vi.fn()
|
||||
const originalLocalStorage = window.localStorage
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseRouter.mockReturnValue({ push: mockPush } as any)
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
plan: {
|
||||
type: AppModeEnum.ADVANCED_CHAT,
|
||||
usage: defaultPlanUsage,
|
||||
total: { ...defaultPlanUsage, buildApps: 1 },
|
||||
reset: {},
|
||||
},
|
||||
enableBilling: true,
|
||||
} as any)
|
||||
mockUseAppContext.mockReturnValue({
|
||||
isCurrentWorkspaceEditor: true,
|
||||
} as any)
|
||||
mockSetItem.mockClear()
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: {
|
||||
setItem: mockSetItem,
|
||||
getItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
key: vi.fn(),
|
||||
length: 0,
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: originalLocalStorage,
|
||||
writable: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('creates an app, notifies success, and fires callbacks', async () => {
|
||||
const mockApp = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT }
|
||||
mockCreateApp.mockResolvedValue(mockApp as any)
|
||||
const { onClose, onSuccess } = renderModal()
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({
|
||||
name: 'My App',
|
||||
description: '',
|
||||
icon_type: 'emoji',
|
||||
icon: '🤖',
|
||||
icon_background: '#FFEAD5',
|
||||
mode: AppModeEnum.ADVANCED_CHAT,
|
||||
}))
|
||||
|
||||
expect(mockTrackEvent).toHaveBeenCalledWith('create_app', {
|
||||
app_mode: AppModeEnum.ADVANCED_CHAT,
|
||||
description: '',
|
||||
})
|
||||
expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' })
|
||||
expect(onSuccess).toHaveBeenCalled()
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
await waitFor(() => expect(mockSetItem).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1'))
|
||||
await waitFor(() => expect(mockGetRedirection).toHaveBeenCalledWith(true, mockApp, mockPush))
|
||||
})
|
||||
|
||||
it('shows error toast when creation fails', async () => {
|
||||
mockCreateApp.mockRejectedValue(new Error('boom'))
|
||||
const { onClose } = renderModal()
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: 'My App' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' }))
|
||||
|
||||
await waitFor(() => expect(mockCreateApp).toHaveBeenCalled())
|
||||
expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' })
|
||||
expect(onClose).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -139,14 +139,14 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t
      id: item.id,
      content: item.answer,
      agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files),
      feedback: item.feedbacks?.find(item => item.from_source === 'user'), // user feedback
      adminFeedback: item.feedbacks?.find(item => item.from_source === 'admin'), // admin feedback
      feedback: item.feedbacks.find(item => item.from_source === 'user'), // user feedback
      adminFeedback: item.feedbacks.find(item => item.from_source === 'admin'), // admin feedback
      feedbackDisabled: false,
      isAnswer: true,
      message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id }))),
      log: [
        ...(item.message ?? []),
        ...(item.message?.[item.message.length - 1]?.role !== 'assistant'
        ...item.message,
        ...(item.message[item.message.length - 1]?.role !== 'assistant'
          ? [
            {
              role: 'assistant',
@@ -165,7 +165,7 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t
      more: {
        time: dayjs.unix(item.created_at).tz(timezone).format(format),
        tokens: item.answer_tokens + item.message_tokens,
        latency: (item.provider_response_latency ?? 0).toFixed(2),
        latency: item.provider_response_latency.toFixed(2),
      },
      citation: item.metadata?.retriever_resources,
      annotation: (() => {

@@ -1,121 +0,0 @@
|
||||
import type { SiteInfo } from '@/models/share'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import copy from 'copy-to-clipboard'
|
||||
import * as React from 'react'
|
||||
|
||||
import { act } from 'react'
|
||||
import { afterAll, afterEach, describe, expect, it, vi } from 'vitest'
|
||||
import Embedded from './index'
|
||||
|
||||
vi.mock('./style.module.css', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
option: 'option',
|
||||
active: 'active',
|
||||
iframeIcon: 'iframeIcon',
|
||||
scriptsIcon: 'scriptsIcon',
|
||||
chromePluginIcon: 'chromePluginIcon',
|
||||
pluginInstallIcon: 'pluginInstallIcon',
|
||||
},
|
||||
}))
|
||||
const mockThemeBuilder = {
|
||||
buildTheme: vi.fn(),
|
||||
theme: {
|
||||
primaryColor: '#123456',
|
||||
},
|
||||
}
|
||||
const mockUseAppContext = vi.fn(() => ({
|
||||
langGeniusVersionInfo: {
|
||||
current_env: 'PRODUCTION',
|
||||
current_version: '',
|
||||
latest_version: '',
|
||||
release_date: '',
|
||||
release_notes: '',
|
||||
version: '',
|
||||
can_auto_update: false,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('copy-to-clipboard', () => ({
|
||||
__esModule: true,
|
||||
default: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/chat/embedded-chatbot/theme/theme-context', () => ({
|
||||
useThemeContext: () => mockThemeBuilder,
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => mockUseAppContext(),
|
||||
}))
|
||||
const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||
const mockedCopy = vi.mocked(copy)
|
||||
|
||||
const siteInfo: SiteInfo = {
|
||||
title: 'test site',
|
||||
chat_color_theme: '#000000',
|
||||
chat_color_theme_inverted: false,
|
||||
}
|
||||
|
||||
const baseProps = {
|
||||
isShow: true,
|
||||
siteInfo,
|
||||
onClose: vi.fn(),
|
||||
appBaseUrl: 'https://app.example.com',
|
||||
accessToken: 'token',
|
||||
className: 'custom-modal',
|
||||
}
|
||||
|
||||
const getCopyButton = () => {
|
||||
const buttons = screen.getAllByRole('button')
|
||||
const actionButton = buttons.find(button => button.className.includes('action-btn'))
|
||||
expect(actionButton).toBeDefined()
|
||||
return actionButton!
|
||||
}
|
||||
|
||||
describe('Embedded', () => {
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWindowOpen.mockClear()
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
mockWindowOpen.mockRestore()
|
||||
})
|
||||
|
||||
it('builds theme and copies iframe snippet', async () => {
|
||||
await act(async () => {
|
||||
render(<Embedded {...baseProps} />)
|
||||
})
|
||||
|
||||
const actionButton = getCopyButton()
|
||||
const innerDiv = actionButton.querySelector('div')
|
||||
act(() => {
|
||||
fireEvent.click(innerDiv ?? actionButton)
|
||||
})
|
||||
|
||||
expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted)
|
||||
expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token'))
|
||||
})
|
||||
|
||||
it('opens chrome plugin store link when chrome option selected', async () => {
|
||||
await act(async () => {
|
||||
render(<Embedded {...baseProps} />)
|
||||
})
|
||||
|
||||
const optionButtons = document.body.querySelectorAll('[class*="option"]')
|
||||
expect(optionButtons.length).toBeGreaterThanOrEqual(3)
|
||||
act(() => {
|
||||
fireEvent.click(optionButtons[2])
|
||||
})
|
||||
|
||||
const [chromeText] = screen.getAllByText('appOverview.overview.appInfo.embedded.chromePlugin')
|
||||
act(() => {
|
||||
fireEvent.click(chromeText)
|
||||
})
|
||||
|
||||
expect(mockWindowOpen).toHaveBeenCalledWith(
|
||||
'https://chrome.google.com/webstore/detail/dify-chatbot/ceehdapohffmjmkdcifjofadiaoeggaf',
|
||||
'_blank',
|
||||
'noopener,noreferrer',
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -1,217 +0,0 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { ModalContextState } from '@/context/modal-context'
|
||||
import type { ProviderContextState } from '@/context/provider-context'
|
||||
import type { AppDetailResponse } from '@/models/app'
|
||||
import type { AppSSO } from '@/types/app'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import { baseProviderContextValue } from '@/context/provider-context'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import SettingsModal from './index'
|
||||
|
||||
vi.mock('react-i18next', async () => {
|
||||
const actual = await vi.importActual<typeof import('react-i18next')>('react-i18next')
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: Record<string, unknown>) => {
|
||||
if (options?.returnObjects)
|
||||
return [`${key}-feature-1`, `${key}-feature-2`]
|
||||
if (options)
|
||||
return `${key}:${JSON.stringify(options)}`
|
||||
return key
|
||||
},
|
||||
i18n: {
|
||||
language: 'en',
|
||||
changeLanguage: vi.fn(),
|
||||
},
|
||||
}),
|
||||
Trans: ({ children }: { children?: ReactNode }) => <>{children}</>,
|
||||
}
|
||||
})
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
const mockOnClose = vi.fn()
|
||||
const mockOnSave = vi.fn()
|
||||
const mockSetShowPricingModal = vi.fn()
|
||||
const mockSetShowAccountSettingModal = vi.fn()
|
||||
const mockUseProviderContext = vi.fn<() => ProviderContextState>()
|
||||
|
||||
const buildModalContext = (): ModalContextState => ({
|
||||
setShowAccountSettingModal: mockSetShowAccountSettingModal,
|
||||
setShowApiBasedExtensionModal: vi.fn(),
|
||||
setShowModerationSettingModal: vi.fn(),
|
||||
setShowExternalDataToolModal: vi.fn(),
|
||||
setShowPricingModal: mockSetShowPricingModal,
|
||||
setShowAnnotationFullModal: vi.fn(),
|
||||
setShowModelModal: vi.fn(),
|
||||
setShowExternalKnowledgeAPIModal: vi.fn(),
|
||||
setShowModelLoadBalancingModal: vi.fn(),
|
||||
setShowOpeningModal: vi.fn(),
|
||||
setShowUpdatePluginModal: vi.fn(),
|
||||
setShowEducationExpireNoticeModal: vi.fn(),
|
||||
setShowTriggerEventsLimitModal: vi.fn(),
|
||||
})
|
||||
|
||||
vi.mock('@/context/modal-context', () => ({
|
||||
useModalContext: () => buildModalContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/app/components/base/toast')>('@/app/components/base/toast')
|
||||
return {
|
||||
...actual,
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
close: vi.fn(),
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/context/i18n')>('@/context/i18n')
|
||||
return {
|
||||
...actual,
|
||||
useDocLink: () => (path?: string) => `https://docs.example.com${path ?? ''}`,
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/provider-context', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/context/provider-context')>('@/context/provider-context')
|
||||
return {
|
||||
...actual,
|
||||
useProviderContext: () => mockUseProviderContext(),
|
||||
}
|
||||
})
|
||||
|
||||
const mockAppInfo = {
|
||||
site: {
|
||||
title: 'Test App',
|
||||
icon_type: 'emoji',
|
||||
icon: '😀',
|
||||
icon_background: '#ABCDEF',
|
||||
icon_url: 'https://example.com/icon.png',
|
||||
description: 'A description',
|
||||
chat_color_theme: '#123456',
|
||||
chat_color_theme_inverted: true,
|
||||
copyright: '© Dify',
|
||||
privacy_policy: '',
|
||||
custom_disclaimer: 'Disclaimer',
|
||||
default_language: 'en-US',
|
||||
show_workflow_steps: true,
|
||||
use_icon_as_answer_icon: true,
|
||||
},
|
||||
mode: AppModeEnum.ADVANCED_CHAT,
|
||||
enable_sso: false,
|
||||
} as unknown as AppDetailResponse & Partial<AppSSO>
|
||||
|
||||
const renderSettingsModal = () => render(
|
||||
<SettingsModal
|
||||
isChat
|
||||
isShow
|
||||
appInfo={mockAppInfo}
|
||||
onClose={mockOnClose}
|
||||
onSave={mockOnSave}
|
||||
/>,
|
||||
)
|
||||
|
||||
describe('SettingsModal', () => {
|
||||
beforeEach(() => {
|
||||
mockNotify.mockClear()
|
||||
mockOnClose.mockClear()
|
||||
mockOnSave.mockClear()
|
||||
mockSetShowPricingModal.mockClear()
|
||||
mockSetShowAccountSettingModal.mockClear()
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
...baseProviderContextValue,
|
||||
enableBilling: true,
|
||||
plan: {
|
||||
...baseProviderContextValue.plan,
|
||||
type: Plan.sandbox,
|
||||
},
|
||||
webappCopyrightEnabled: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('should render the modal and expose the expanded settings section', async () => {
|
||||
renderSettingsModal()
|
||||
expect(screen.getByText('appOverview.overview.appInfo.settings.title')).toBeInTheDocument()
|
||||
|
||||
const showMoreEntry = screen.getByText('appOverview.overview.appInfo.settings.more.entry')
|
||||
fireEvent.click(showMoreEntry)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.copyRightPlaceholder')).toBeInTheDocument()
|
||||
expect(screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.privacyPolicyPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should notify the user when the name is empty', async () => {
|
||||
renderSettingsModal()
|
||||
const nameInput = screen.getByPlaceholderText('app.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: '' } })
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ message: 'app.newApp.nameNotEmpty' }))
|
||||
})
|
||||
expect(mockOnSave).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should validate the theme color and show an error when the hex is invalid', async () => {
|
||||
renderSettingsModal()
|
||||
const colorInput = screen.getByPlaceholderText('E.g #A020F0')
|
||||
fireEvent.change(colorInput, { target: { value: 'not-a-hex' } })
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
message: 'appOverview.overview.appInfo.settings.invalidHexMessage',
|
||||
}))
|
||||
})
|
||||
expect(mockOnSave).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should validate the privacy policy URL when advanced settings are open', async () => {
|
||||
renderSettingsModal()
|
||||
fireEvent.click(screen.getByText('appOverview.overview.appInfo.settings.more.entry'))
|
||||
const privacyInput = screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.privacyPolicyPlaceholder')
|
||||
// eslint-disable-next-line sonarjs/no-clear-text-protocols
|
||||
fireEvent.change(privacyInput, { target: { value: 'ftp://invalid-url' } })
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
message: 'appOverview.overview.appInfo.settings.invalidPrivacyPolicy',
|
||||
}))
|
||||
})
|
||||
expect(mockOnSave).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should save valid settings and close the modal', async () => {
|
||||
mockOnSave.mockResolvedValueOnce(undefined)
|
||||
renderSettingsModal()
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
|
||||
await waitFor(() => expect(mockOnSave).toHaveBeenCalled())
|
||||
expect(mockOnSave).toHaveBeenCalledWith(expect.objectContaining({
|
||||
title: mockAppInfo.site.title,
|
||||
description: mockAppInfo.site.description,
|
||||
default_language: mockAppInfo.site.default_language,
|
||||
chat_color_theme: mockAppInfo.site.chat_color_theme,
|
||||
chat_color_theme_inverted: mockAppInfo.site.chat_color_theme_inverted,
|
||||
prompt_public: false,
|
||||
copyright: mockAppInfo.site.copyright,
|
||||
privacy_policy: mockAppInfo.site.privacy_policy,
|
||||
custom_disclaimer: mockAppInfo.site.custom_disclaimer,
|
||||
icon_type: 'emoji',
|
||||
icon: mockAppInfo.site.icon,
|
||||
icon_background: mockAppInfo.site.icon_background,
|
||||
show_workflow_steps: mockAppInfo.site.show_workflow_steps,
|
||||
use_icon_as_answer_icon: mockAppInfo.site.use_icon_as_answer_icon,
|
||||
enable_sso: mockAppInfo.enable_sso,
|
||||
}))
|
||||
expect(mockOnClose).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -1,67 +0,0 @@
import type { ISavedItemsProps } from './index'
import { fireEvent, render, screen } from '@testing-library/react'
import copy from 'copy-to-clipboard'

import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Toast from '@/app/components/base/toast'
import SavedItems from './index'

vi.mock('copy-to-clipboard', () => ({
  __esModule: true,
  default: vi.fn(),
}))
vi.mock('next/navigation', () => ({
  useParams: () => ({}),
  usePathname: () => '/',
}))

const mockCopy = vi.mocked(copy)
const toastNotifySpy = vi.spyOn(Toast, 'notify')

const baseProps: ISavedItemsProps = {
  list: [
    { id: '1', answer: 'hello world' },
  ],
  isShowTextToSpeech: true,
  onRemove: vi.fn(),
  onStartCreateContent: vi.fn(),
}

describe('SavedItems', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    toastNotifySpy.mockClear()
  })

  it('renders saved answers with metadata and controls', () => {
    const { container } = render(<SavedItems {...baseProps} />)

    const markdownElement = container.querySelector('.markdown-body')
    expect(markdownElement).toBeInTheDocument()
    expect(screen.getByText('11 common.unit.char')).toBeInTheDocument()

    const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
    const actionButtons = actionArea?.querySelectorAll('button') ?? []
    expect(actionButtons.length).toBeGreaterThanOrEqual(3)
  })

  it('copies content and notifies, and triggers remove callback', () => {
    const handleRemove = vi.fn()
    const { container } = render(<SavedItems {...baseProps} onRemove={handleRemove} />)

    const actionArea = container.querySelector('[class*="bg-components-actionbar-bg"]')
    const actionButtons = actionArea?.querySelectorAll('button') ?? []
    expect(actionButtons.length).toBeGreaterThanOrEqual(3)

    const copyButton = actionButtons[1]
    const deleteButton = actionButtons[2]

    fireEvent.click(copyButton)
    expect(mockCopy).toHaveBeenCalledWith('hello world')
    expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'common.actionMsg.copySuccessfully' })

    fireEvent.click(deleteButton)
    expect(handleRemove).toHaveBeenCalledWith('1')
  })
})
@@ -1,22 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'

import NoData from './index'

describe('NoData', () => {
  it('renders title/description and calls callback when button clicked', () => {
    const handleStart = vi.fn()
    render(<NoData onStartCreateContent={handleStart} />)

    const title = screen.getByText('share.generation.savedNoData.title')
    const description = screen.getByText('share.generation.savedNoData.description')
    const button = screen.getByRole('button', { name: 'share.generation.savedNoData.startCreateContent' })

    expect(title).toBeInTheDocument()
    expect(description).toBeInTheDocument()
    expect(button).toBeInTheDocument()

    fireEvent.click(button)
    expect(handleStart).toHaveBeenCalledTimes(1)
  })
})
web/app/components/apps/hooks/use-apps-query-state.spec.ts
@@ -0,0 +1,363 @@
|
||||
/**
|
||||
* Test suite for useAppsQueryState hook
|
||||
*
|
||||
* This hook manages app filtering state through URL search parameters, enabling:
|
||||
* - Bookmarkable filter states (users can share URLs with specific filters active)
|
||||
* - Browser history integration (back/forward buttons work with filters)
|
||||
* - Multiple filter types: tagIDs, keywords, isCreatedByMe
|
||||
*
|
||||
* The hook syncs local filter state with URL search parameters, making filter
|
||||
* navigation persistent and shareable across sessions.
|
||||
*/
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
|
||||
// Import the hook after mocks are set up
|
||||
import useAppsQueryState from './use-apps-query-state'
|
||||
|
||||
// Mock Next.js navigation hooks
|
||||
const mockPush = vi.fn()
|
||||
const mockPathname = '/apps'
|
||||
let mockSearchParams = new URLSearchParams()
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
usePathname: vi.fn(() => mockPathname),
|
||||
useRouter: vi.fn(() => ({
|
||||
push: mockPush,
|
||||
})),
|
||||
useSearchParams: vi.fn(() => mockSearchParams),
|
||||
}))
|
||||
|
||||
describe('useAppsQueryState', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockSearchParams = new URLSearchParams()
|
||||
})
|
||||
|
||||
describe('Basic functionality', () => {
|
||||
it('should return query object and setQuery function', () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query).toBeDefined()
|
||||
expect(typeof result.current.setQuery).toBe('function')
|
||||
})
|
||||
|
||||
it('should initialize with empty query when no search params exist', () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.tagIDs).toBeUndefined()
|
||||
expect(result.current.query.keywords).toBeUndefined()
|
||||
expect(result.current.query.isCreatedByMe).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Parsing search params', () => {
|
||||
it('should parse tagIDs from URL', () => {
|
||||
mockSearchParams.set('tagIDs', 'tag1;tag2;tag3')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2', 'tag3'])
|
||||
})
|
||||
|
||||
it('should parse single tagID from URL', () => {
|
||||
mockSearchParams.set('tagIDs', 'single-tag')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.tagIDs).toEqual(['single-tag'])
|
||||
})
|
||||
|
||||
it('should parse keywords from URL', () => {
|
||||
mockSearchParams.set('keywords', 'search term')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.keywords).toBe('search term')
|
||||
})
|
||||
|
||||
it('should parse isCreatedByMe as true from URL', () => {
|
||||
mockSearchParams.set('isCreatedByMe', 'true')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
})
|
||||
|
||||
it('should parse isCreatedByMe as false for other values', () => {
|
||||
mockSearchParams.set('isCreatedByMe', 'false')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.isCreatedByMe).toBe(false)
|
||||
})
|
||||
|
||||
it('should parse all params together', () => {
|
||||
mockSearchParams.set('tagIDs', 'tag1;tag2')
|
||||
mockSearchParams.set('keywords', 'test')
|
||||
mockSearchParams.set('isCreatedByMe', 'true')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
|
||||
expect(result.current.query.keywords).toBe('test')
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Updating query state', () => {
|
||||
it('should update keywords via setQuery', () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: 'new search' })
|
||||
})
|
||||
|
||||
expect(result.current.query.keywords).toBe('new search')
|
||||
})
|
||||
|
||||
it('should update tagIDs via setQuery', () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ tagIDs: ['tag1', 'tag2'] })
|
||||
})
|
||||
|
||||
expect(result.current.query.tagIDs).toEqual(['tag1', 'tag2'])
|
||||
})
|
||||
|
||||
it('should update isCreatedByMe via setQuery', () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ isCreatedByMe: true })
|
||||
})
|
||||
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
})
|
||||
|
||||
it('should support partial updates via callback', () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: 'initial' })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true }))
|
||||
})
|
||||
|
||||
expect(result.current.query.keywords).toBe('initial')
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('URL synchronization', () => {
|
||||
it('should sync keywords to URL', async () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: 'search' })
|
||||
})
|
||||
|
||||
// Wait for useEffect to run
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
expect.stringContaining('keywords=search'),
|
||||
{ scroll: false },
|
||||
)
|
||||
})
|
||||
|
||||
it('should sync tagIDs to URL with semicolon separator', async () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ tagIDs: ['tag1', 'tag2'] })
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
expect.stringContaining('tagIDs=tag1%3Btag2'),
|
||||
{ scroll: false },
|
||||
)
|
||||
})
|
||||
|
||||
it('should sync isCreatedByMe to URL', async () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ isCreatedByMe: true })
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
expect.stringContaining('isCreatedByMe=true'),
|
||||
{ scroll: false },
|
||||
)
|
||||
})
|
||||
|
||||
it('should remove keywords from URL when empty', async () => {
|
||||
mockSearchParams.set('keywords', 'existing')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: '' })
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
// Should be called without keywords param
|
||||
expect(mockPush).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should remove tagIDs from URL when empty array', async () => {
|
||||
mockSearchParams.set('tagIDs', 'tag1;tag2')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ tagIDs: [] })
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mockPush).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should remove isCreatedByMe from URL when false', async () => {
|
||||
mockSearchParams.set('isCreatedByMe', 'true')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ isCreatedByMe: false })
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mockPush).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should handle empty tagIDs string in URL', () => {
|
||||
// NOTE: This test documents current behavior where ''.split(';') returns ['']
|
||||
// This could potentially cause filtering issues as it's treated as a tag with empty name
|
||||
// rather than absence of tags. Consider updating parseParams if this is problematic.
|
||||
mockSearchParams.set('tagIDs', '')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.tagIDs).toEqual([''])
|
||||
})
|
||||
|
||||
it('should handle empty keywords', () => {
|
||||
mockSearchParams.set('keywords', '')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.keywords).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle undefined tagIDs', () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ tagIDs: undefined })
|
||||
})
|
||||
|
||||
expect(result.current.query.tagIDs).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle special characters in keywords', () => {
|
||||
// Use URLSearchParams constructor to properly simulate URL decoding behavior
|
||||
// URLSearchParams.get() decodes URL-encoded characters
|
||||
mockSearchParams = new URLSearchParams('keywords=test%20with%20spaces')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
expect(result.current.query.keywords).toBe('test with spaces')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Memoization', () => {
|
||||
it('should return memoized object reference when query unchanged', () => {
|
||||
const { result, rerender } = renderHook(() => useAppsQueryState())
|
||||
|
||||
const firstResult = result.current
|
||||
rerender()
|
||||
const secondResult = result.current
|
||||
|
||||
expect(firstResult.query).toBe(secondResult.query)
|
||||
})
|
||||
|
||||
it('should return new object reference when query changes', () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
const firstQuery = result.current.query
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: 'changed' })
|
||||
})
|
||||
|
||||
expect(result.current.query).not.toBe(firstQuery)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Integration scenarios', () => {
|
||||
it('should handle sequential updates', async () => {
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({ keywords: 'first' })
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery(prev => ({ ...prev, tagIDs: ['tag1'] }))
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery(prev => ({ ...prev, isCreatedByMe: true }))
|
||||
})
|
||||
|
||||
expect(result.current.query.keywords).toBe('first')
|
||||
expect(result.current.query.tagIDs).toEqual(['tag1'])
|
||||
expect(result.current.query.isCreatedByMe).toBe(true)
|
||||
})
|
||||
|
||||
it('should clear all filters', () => {
|
||||
mockSearchParams.set('tagIDs', 'tag1;tag2')
|
||||
mockSearchParams.set('keywords', 'search')
|
||||
mockSearchParams.set('isCreatedByMe', 'true')
|
||||
|
||||
const { result } = renderHook(() => useAppsQueryState())
|
||||
|
||||
act(() => {
|
||||
result.current.setQuery({
|
||||
tagIDs: undefined,
|
||||
keywords: undefined,
|
||||
isCreatedByMe: false,
|
||||
})
|
||||
})
|
||||
|
||||
expect(result.current.query.tagIDs).toBeUndefined()
|
||||
expect(result.current.query.keywords).toBeUndefined()
|
||||
expect(result.current.query.isCreatedByMe).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||