mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
Merge branch 'main' into feat/r2
# Conflicts: # api/core/repositories/sqlalchemy_workflow_node_execution_repository.py # api/core/workflow/entities/node_entities.py # api/core/workflow/enums.py
This commit is contained in:
@ -14,7 +14,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Tenant
|
||||
from models.model import App, Conversation, Message
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
|
||||
from services.billing_service import BillingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -108,10 +108,11 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
workflow_node_executions = (
|
||||
session.query(WorkflowNodeExecution)
|
||||
session.query(WorkflowNodeExecutionModel)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == tenant_id,
|
||||
WorkflowNodeExecution.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.created_at
|
||||
< datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
@ -135,8 +136,8 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
]
|
||||
|
||||
# delete workflow node executions
|
||||
session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id.in_(workflow_node_execution_ids),
|
||||
session.query(WorkflowNodeExecutionModel).filter(
|
||||
WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
|
||||
@ -2,8 +2,11 @@ import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.model_runtime.entities import LLMMode
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@ -34,7 +37,29 @@ class HitTestingService:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
if not retrieval_model:
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
document_ids_filter = None
|
||||
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
|
||||
if metadata_filtering_conditions:
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
from core.app.app_config.entities import MetadataFilteringCondition
|
||||
|
||||
metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)
|
||||
|
||||
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=[dataset.id],
|
||||
query=query,
|
||||
metadata_filtering_mode="manual",
|
||||
metadata_filtering_conditions=metadata_filtering_conditions,
|
||||
inputs={},
|
||||
tenant_id="",
|
||||
user_id="",
|
||||
metadata_model_config=ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={}),
|
||||
)
|
||||
if metadata_filter_document_ids:
|
||||
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return cls.compact_retrieve_response(query, [])
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
@ -48,6 +73,7 @@ class HitTestingService:
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
@ -99,7 +125,7 @@ class HitTestingService:
|
||||
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
|
||||
|
||||
@classmethod
|
||||
def compact_retrieve_response(cls, query: str, documents: list[Document]):
|
||||
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
|
||||
return {
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, TraceAppConfig
|
||||
@ -92,13 +93,12 @@ class OpsService:
|
||||
except KeyError:
|
||||
return {"error": f"Invalid tracing provider: {tracing_provider}"}
|
||||
|
||||
config_class, other_keys = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
# FIXME: ignore type error
|
||||
default_config_instance = config_class(**tracing_config) # type: ignore
|
||||
for key in other_keys: # type: ignore
|
||||
provider_config: dict[str, Any] = provider_config_map[tracing_provider]
|
||||
config_class: type[BaseTracingConfig] = provider_config["config_class"]
|
||||
other_keys: list[str] = provider_config["other_keys"]
|
||||
|
||||
default_config_instance: BaseTracingConfig = config_class(**tracing_config)
|
||||
for key in other_keys:
|
||||
if key in tracing_config and tracing_config[key] == "":
|
||||
tracing_config[key] = getattr(default_config_instance, key, None)
|
||||
|
||||
|
||||
@ -44,6 +44,17 @@ class TagService:
|
||||
results = [tag_binding.target_id for tag_binding in tag_bindings]
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
.all()
|
||||
)
|
||||
if not tags:
|
||||
return []
|
||||
return tags
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
|
||||
tags = (
|
||||
@ -62,6 +73,8 @@ class TagService:
|
||||
|
||||
@staticmethod
|
||||
def save_tags(args: dict) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = Tag(
|
||||
id=str(uuid.uuid4()),
|
||||
name=args["name"],
|
||||
@ -75,6 +88,8 @@ class TagService:
|
||||
|
||||
@staticmethod
|
||||
def update_tags(args: dict, tag_id: str) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
|
||||
@ -173,26 +173,27 @@ class WebsiteService:
|
||||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None:
|
||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
# FIXME data is redefine too many times here, use Any to ease the type checking, fix it later
|
||||
data: Any
|
||||
|
||||
if provider == "firecrawl":
|
||||
crawl_data: list[dict[str, Any]] | None = None
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
d = storage.load_once(file_key)
|
||||
if d:
|
||||
data = json.loads(d.decode("utf-8"))
|
||||
stored_data = storage.load_once(file_key)
|
||||
if stored_data:
|
||||
crawl_data = json.loads(stored_data.decode("utf-8"))
|
||||
else:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get("status") != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
data = result.get("data")
|
||||
if data:
|
||||
for item in data:
|
||||
crawl_data = result.get("data")
|
||||
|
||||
if crawl_data:
|
||||
for item in crawl_data:
|
||||
if item.get("source_url") == url:
|
||||
return dict(item)
|
||||
return None
|
||||
@ -211,23 +212,24 @@ class WebsiteService:
|
||||
raise ValueError("Failed to crawl")
|
||||
return dict(response.json().get("data", {}))
|
||||
else:
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
response = requests.post(
|
||||
# Get crawl status first
|
||||
status_response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
)
|
||||
data = response.json().get("data", {})
|
||||
if data.get("status") != "completed":
|
||||
status_data = status_response.json().get("data", {})
|
||||
if status_data.get("status") != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
|
||||
response = requests.post(
|
||||
# Get processed data
|
||||
data_response = requests.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
|
||||
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
|
||||
)
|
||||
data = response.json().get("data", {})
|
||||
for item in data.get("processed", {}).values():
|
||||
processed_data = data_response.json().get("data", {})
|
||||
for item in processed_data.get("processed", {}).values():
|
||||
if item.get("data", {}).get("url") == url:
|
||||
return dict(item.get("data", {}))
|
||||
return None
|
||||
|
||||
@ -4,9 +4,9 @@ from datetime import datetime
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from models import App, EndUser, WorkflowAppLog, WorkflowRun
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import WorkflowRunStatus
|
||||
|
||||
|
||||
class WorkflowAppService:
|
||||
@ -16,7 +16,7 @@ class WorkflowAppService:
|
||||
session: Session,
|
||||
app_model: App,
|
||||
keyword: str | None = None,
|
||||
status: WorkflowRunStatus | None = None,
|
||||
status: WorkflowExecutionStatus | None = None,
|
||||
created_at_before: datetime | None = None,
|
||||
created_at_after: datetime | None = None,
|
||||
page: int = 1,
|
||||
|
||||
@ -4,14 +4,14 @@ from typing import Optional
|
||||
|
||||
import contexts
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import (
|
||||
Account,
|
||||
App,
|
||||
EndUser,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowRun,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
@ -125,7 +125,7 @@ class WorkflowRunService:
|
||||
app_model: App,
|
||||
run_id: str,
|
||||
user: Account | EndUser,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
|
||||
@ -13,7 +13,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.variables import Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
@ -30,8 +30,7 @@ from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowType,
|
||||
)
|
||||
@ -254,7 +253,7 @@ class WorkflowService:
|
||||
|
||||
def run_draft_workflow_node(
|
||||
self, app_model: App, node_id: str, user_inputs: dict, account: Account
|
||||
) -> WorkflowNodeExecution:
|
||||
) -> WorkflowNodeExecutionModel:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
@ -296,7 +295,7 @@ class WorkflowService:
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> NodeExecution:
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
@ -322,7 +321,7 @@ class WorkflowService:
|
||||
invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
|
||||
start_at: float,
|
||||
node_id: str,
|
||||
) -> NodeExecution:
|
||||
) -> WorkflowNodeExecution:
|
||||
try:
|
||||
node_instance, generator = invoke_node_fn()
|
||||
|
||||
@ -374,7 +373,7 @@ class WorkflowService:
|
||||
error = e.error
|
||||
|
||||
# Create a NodeExecution domain model
|
||||
node_execution = NodeExecution(
|
||||
node_execution = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
workflow_id="", # This is a single-step execution, so no workflow ID
|
||||
index=1,
|
||||
@ -403,13 +402,13 @@ class WorkflowService:
|
||||
|
||||
# Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
node_execution.status = NodeExecutionStatus.SUCCEEDED
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
|
||||
node_execution.status = NodeExecutionStatus.EXCEPTION
|
||||
node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
|
||||
node_execution.error = node_run_result.error
|
||||
else:
|
||||
# Set failed status and error
|
||||
node_execution.status = NodeExecutionStatus.FAILED
|
||||
node_execution.status = WorkflowNodeExecutionStatus.FAILED
|
||||
node_execution.error = error
|
||||
|
||||
return node_execution
|
||||
|
||||
Reference in New Issue
Block a user