mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
orm filter -> where (#22801)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@ -99,7 +99,7 @@ class BaseAgentRunner(AppRunner):
|
||||
# get how many agent thoughts have been created
|
||||
self.agent_thought_count = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.filter(
|
||||
.where(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
)
|
||||
.count()
|
||||
@ -336,7 +336,7 @@ class BaseAgentRunner(AppRunner):
|
||||
Save agent thought
|
||||
"""
|
||||
updated_agent_thought = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first()
|
||||
)
|
||||
if not updated_agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
@ -496,7 +496,7 @@ class BaseAgentRunner(AppRunner):
|
||||
return result
|
||||
|
||||
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
||||
if not files:
|
||||
return UserPromptMessage(content=message.query)
|
||||
if message.app_model_config:
|
||||
|
||||
@ -72,7 +72,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(AgentChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -183,10 +183,10 @@ class AgentChatAppRunner(AppRunner):
|
||||
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
|
||||
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
|
||||
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
message_result = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
message_result = db.session.query(Message).where(Message.id == message.id).first()
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
@ -43,7 +43,7 @@ class ChatAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(ChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
|
||||
@ -248,7 +248,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
"""
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
.where(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
|
||||
@ -36,7 +36,7 @@ class CompletionAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(CompletionAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
|
||||
@ -85,7 +85,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
if conversation:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig)
|
||||
.filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
|
||||
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -259,7 +259,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
:param conversation_id: conversation id
|
||||
:return: conversation
|
||||
"""
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError("Conversation not exists")
|
||||
@ -272,7 +272,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
:param message_id: message id
|
||||
:return: message
|
||||
"""
|
||||
message = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
message = db.session.query(Message).where(Message.id == message_id).first()
|
||||
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
@ -26,7 +26,7 @@ class AnnotationReplyFeature:
|
||||
:return:
|
||||
"""
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
|
||||
)
|
||||
|
||||
if not annotation_setting:
|
||||
|
||||
@ -471,7 +471,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
:return:
|
||||
"""
|
||||
agent_thought: Optional[MessageAgentThought] = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
|
||||
if agent_thought:
|
||||
|
||||
@ -81,7 +81,7 @@ class MessageCycleManager:
|
||||
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
||||
with flask_app.app_context():
|
||||
# get conversation and message
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
|
||||
|
||||
if not conversation:
|
||||
return
|
||||
@ -140,7 +140,7 @@ class MessageCycleManager:
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
|
||||
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
|
||||
|
||||
if message_file and message_file.url is not None:
|
||||
# get tool file id
|
||||
|
||||
@ -49,7 +49,7 @@ class DatasetIndexToolCallbackHandler:
|
||||
for document in documents:
|
||||
if document.metadata is not None:
|
||||
document_id = document.metadata["document_id"]
|
||||
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
|
||||
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
|
||||
if not dataset_document:
|
||||
_logger.warning(
|
||||
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
|
||||
@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
.filter(
|
||||
.where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
@ -69,18 +69,18 @@ class DatasetIndexToolCallbackHandler:
|
||||
if child_chunk:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.id == child_chunk.segment_id)
|
||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
||||
.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
|
||||
@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
.where(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(provider_names),
|
||||
@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
provider_model_record = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(
|
||||
.where(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name.in_(provider_names),
|
||||
ProviderModel.model_name == model,
|
||||
@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
.where(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
load_balancing_config_count = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
.where(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
preferred_model_provider = (
|
||||
db.session.query(TenantPreferredModelProvider)
|
||||
.filter(
|
||||
.where(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
||||
)
|
||||
|
||||
@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
# get api_based_extension
|
||||
api_based_extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
# get api_based_extension
|
||||
api_based_extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ def encrypt_token(tenant_id: str, token: str):
|
||||
from models.account import Tenant
|
||||
from models.engine import db
|
||||
|
||||
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
|
||||
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
|
||||
raise ValueError(f"Tenant with id {tenant_id} not found")
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
return base64.b64encode(encrypted_token).decode()
|
||||
|
||||
@ -59,7 +59,7 @@ class IndexingRunner:
|
||||
# get the process rule
|
||||
processing_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
|
||||
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
|
||||
.first()
|
||||
)
|
||||
if not processing_rule:
|
||||
@ -119,12 +119,12 @@ class IndexingRunner:
|
||||
db.session.delete(document_segment)
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
# delete child chunks
|
||||
db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
|
||||
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
|
||||
db.session.commit()
|
||||
# get the process rule
|
||||
processing_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
|
||||
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
|
||||
.first()
|
||||
)
|
||||
if not processing_rule:
|
||||
@ -212,7 +212,7 @@ class IndexingRunner:
|
||||
# get the process rule
|
||||
processing_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
|
||||
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -316,7 +316,7 @@ class IndexingRunner:
|
||||
# delete image files and related db records
|
||||
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
|
||||
for upload_file_id in image_upload_file_ids:
|
||||
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
|
||||
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
if image_file is None:
|
||||
continue
|
||||
try:
|
||||
@ -346,7 +346,7 @@ class IndexingRunner:
|
||||
raise ValueError("no upload file found")
|
||||
|
||||
file_detail = (
|
||||
db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
|
||||
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
|
||||
)
|
||||
|
||||
if file_detail:
|
||||
@ -599,7 +599,7 @@ class IndexingRunner:
|
||||
keyword.create(documents)
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
document_ids = [document.metadata["doc_id"] for document in documents]
|
||||
db.session.query(DocumentSegment).filter(
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
@ -630,7 +630,7 @@ class IndexingRunner:
|
||||
index_processor.load(dataset, chunk_documents, with_keywords=False)
|
||||
|
||||
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
|
||||
db.session.query(DocumentSegment).filter(
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
|
||||
@ -28,7 +28,7 @@ class MCPServerStreamableHTTPRequestHandler:
|
||||
):
|
||||
self.app = app
|
||||
self.request = request
|
||||
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
|
||||
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
|
||||
if not mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server: AppMCPServer = mcp_server
|
||||
@ -192,7 +192,7 @@ class MCPServerStreamableHTTPRequestHandler:
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
@ -67,7 +67,7 @@ class TokenBufferMemory:
|
||||
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
||||
if files:
|
||||
file_extra_config = None
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
|
||||
@ -89,7 +89,7 @@ class ApiModeration(Moderation):
|
||||
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
|
||||
extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
@ -120,7 +120,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
@ -244,14 +244,14 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
app = session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
service_account = session.query(Account).where(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
current_tenant = (
|
||||
|
||||
@ -297,7 +297,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
# Add end user data if available
|
||||
if trace_info.message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
|
||||
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
message_metadata["end_user_id"] = end_user_data.session_id
|
||||
@ -703,7 +703,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
return workflow_nodes
|
||||
|
||||
@ -44,14 +44,14 @@ class BaseTraceInstance(ABC):
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
app = session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
service_account = session.query(Account).where(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
|
||||
|
||||
@ -244,7 +244,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
|
||||
@ -262,7 +262,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
|
||||
@ -284,7 +284,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
|
||||
@ -218,7 +218,7 @@ class OpsTraceManager:
|
||||
"""
|
||||
trace_config_data: Optional[TraceAppConfig] = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -226,7 +226,7 @@ class OpsTraceManager:
|
||||
return None
|
||||
|
||||
# decrypt_token
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -253,7 +253,7 @@ class OpsTraceManager:
|
||||
if app_id is None:
|
||||
return None
|
||||
|
||||
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
||||
app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
|
||||
|
||||
if app is None:
|
||||
return None
|
||||
@ -293,18 +293,18 @@ class OpsTraceManager:
|
||||
@classmethod
|
||||
def get_app_config_through_message_id(cls, message_id: str):
|
||||
app_model_config = None
|
||||
message_data = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
message_data = db.session.query(Message).where(Message.id == message_id).first()
|
||||
if not message_data:
|
||||
return None
|
||||
conversation_id = message_data.conversation_id
|
||||
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
|
||||
if not conversation_data:
|
||||
return None
|
||||
|
||||
if conversation_data.app_model_config_id:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig)
|
||||
.filter(AppModelConfig.id == conversation_data.app_model_config_id)
|
||||
.where(AppModelConfig.id == conversation_data.app_model_config_id)
|
||||
.first()
|
||||
)
|
||||
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
|
||||
@ -331,7 +331,7 @@ class OpsTraceManager:
|
||||
if tracing_provider is not None:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
||||
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app_config:
|
||||
raise ValueError("App not found")
|
||||
app_config.tracing = json.dumps(
|
||||
@ -349,7 +349,7 @@ class OpsTraceManager:
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
||||
app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
if not app.tracing:
|
||||
|
||||
@ -3,6 +3,8 @@ from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
|
||||
@ -20,7 +22,7 @@ def filter_none_values(data: dict):
|
||||
|
||||
|
||||
def get_message_data(message_id: str):
|
||||
return db.session.query(Message).filter(Message.id == message_id).first()
|
||||
return db.session.scalar(select(Message).where(Message.id == message_id))
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@ -235,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
|
||||
@ -193,9 +193,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
get the user by user id
|
||||
"""
|
||||
|
||||
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
|
||||
if not user:
|
||||
user = db.session.query(Account).filter(Account.id == user_id).first()
|
||||
user = db.session.query(Account).where(Account.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
@ -208,7 +208,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
get app
|
||||
"""
|
||||
try:
|
||||
app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first()
|
||||
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
|
||||
except Exception:
|
||||
raise ValueError("app not found")
|
||||
|
||||
|
||||
@ -275,7 +275,7 @@ class ProviderManager:
|
||||
# Get the corresponding TenantDefaultModel record
|
||||
default_model = (
|
||||
db.session.query(TenantDefaultModel)
|
||||
.filter(
|
||||
.where(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@ -367,7 +367,7 @@ class ProviderManager:
|
||||
# Get the list of available models from get_configurations and check if it is LLM
|
||||
default_model = (
|
||||
db.session.query(TenantDefaultModel)
|
||||
.filter(
|
||||
.where(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@ -541,7 +541,7 @@ class ProviderManager:
|
||||
db.session.rollback()
|
||||
existed_provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
|
||||
@ -93,11 +93,11 @@ class Jieba(BaseKeyword):
|
||||
|
||||
documents = []
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment_query = db.session.query(DocumentSegment).filter(
|
||||
segment_query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
segment = segment_query.first()
|
||||
|
||||
if segment:
|
||||
@ -214,7 +214,7 @@ class Jieba(BaseKeyword):
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
|
||||
document_segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
|
||||
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
|
||||
.first()
|
||||
)
|
||||
if document_segment:
|
||||
|
||||
@ -127,7 +127,7 @@ class RetrievalService:
|
||||
external_retrieval_model: Optional[dict] = None,
|
||||
metadata_filtering_conditions: Optional[dict] = None,
|
||||
):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
metadata_condition = (
|
||||
@ -145,7 +145,7 @@ class RetrievalService:
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
|
||||
with Session(db.engine) as session:
|
||||
return session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
return session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
@ -294,7 +294,7 @@ class RetrievalService:
|
||||
dataset_documents = {
|
||||
doc.id: doc
|
||||
for doc in db.session.query(DatasetDocument)
|
||||
.filter(DatasetDocument.id.in_(document_ids))
|
||||
.where(DatasetDocument.id.in_(document_ids))
|
||||
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
|
||||
.all()
|
||||
}
|
||||
@ -318,7 +318,7 @@ class RetrievalService:
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
|
||||
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
|
||||
)
|
||||
|
||||
if not child_chunk:
|
||||
@ -326,7 +326,7 @@ class RetrievalService:
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
@ -381,7 +381,7 @@ class RetrievalService:
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
|
||||
@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_collection_binding:
|
||||
|
||||
@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector):
|
||||
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
)
|
||||
if not tidb_auth_binding:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
else:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
@ -47,7 +47,7 @@ class Vector:
|
||||
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
|
||||
whitelist = (
|
||||
db.session.query(Whitelist)
|
||||
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
|
||||
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
|
||||
.one_or_none()
|
||||
)
|
||||
if whitelist:
|
||||
|
||||
@ -42,7 +42,7 @@ class DatasetDocumentStore:
|
||||
@property
|
||||
def docs(self) -> dict[str, Document]:
|
||||
document_segments = (
|
||||
db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all()
|
||||
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
|
||||
)
|
||||
|
||||
output = {}
|
||||
@ -63,7 +63,7 @@ class DatasetDocumentStore:
|
||||
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
|
||||
max_position = (
|
||||
db.session.query(func.max(DocumentSegment.position))
|
||||
.filter(DocumentSegment.document_id == self._document_id)
|
||||
.where(DocumentSegment.document_id == self._document_id)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
@ -147,7 +147,7 @@ class DatasetDocumentStore:
|
||||
segment_document.tokens = tokens
|
||||
if save_child and doc.children:
|
||||
# delete the existing child chunks
|
||||
db.session.query(ChildChunk).filter(
|
||||
db.session.query(ChildChunk).where(
|
||||
ChildChunk.tenant_id == self._dataset.tenant_id,
|
||||
ChildChunk.dataset_id == self._dataset.id,
|
||||
ChildChunk.document_id == self._document_id,
|
||||
@ -230,7 +230,7 @@ class DatasetDocumentStore:
|
||||
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
|
||||
document_segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
|
||||
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
@ -366,7 +366,7 @@ class NotionExtractor(BaseExtractor):
|
||||
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.filter(
|
||||
.where(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
|
||||
@ -118,7 +118,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_node_ids = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
|
||||
.filter(
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
|
||||
vector.delete_by_ids(child_node_ids)
|
||||
if delete_child_chunks:
|
||||
db.session.query(ChildChunk).filter(
|
||||
db.session.query(ChildChunk).where(
|
||||
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
|
||||
).delete()
|
||||
db.session.commit()
|
||||
@ -136,7 +136,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
vector.delete()
|
||||
|
||||
if delete_child_chunks:
|
||||
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
|
||||
db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
|
||||
db.session.commit()
|
||||
|
||||
def retrieve(
|
||||
|
||||
@ -135,7 +135,7 @@ class DatasetRetrieval:
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
@ -242,7 +242,7 @@ class DatasetRetrieval:
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.filter(
|
||||
.where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
@ -327,7 +327,7 @@ class DatasetRetrieval:
|
||||
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
results = []
|
||||
if dataset.provider == "external":
|
||||
@ -516,14 +516,14 @@ class DatasetRetrieval:
|
||||
if document.metadata is not None:
|
||||
dataset_document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.filter(DatasetDocument.id == document.metadata["document_id"])
|
||||
.where(DatasetDocument.id == document.metadata["document_id"])
|
||||
.first()
|
||||
)
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
.filter(
|
||||
.where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
@ -533,7 +533,7 @@ class DatasetRetrieval:
|
||||
if child_chunk:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.id == child_chunk.segment_id)
|
||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
||||
.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
@ -541,13 +541,13 @@ class DatasetRetrieval:
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update(
|
||||
@ -600,7 +600,7 @@ class DatasetRetrieval:
|
||||
):
|
||||
with flask_app.app_context():
|
||||
with Session(db.engine) as session:
|
||||
dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
@ -685,7 +685,7 @@ class DatasetRetrieval:
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
@ -862,7 +862,7 @@ class DatasetRetrieval:
|
||||
metadata_filtering_conditions: Optional[MetadataFilteringCondition],
|
||||
inputs: dict,
|
||||
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
|
||||
document_query = db.session.query(DatasetDocument).filter(
|
||||
document_query = db.session.query(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id.in_(dataset_ids),
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
@ -930,9 +930,9 @@ class DatasetRetrieval:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore
|
||||
document_query = document_query.filter(and_(*filters))
|
||||
document_query = document_query.where(and_(*filters))
|
||||
else:
|
||||
document_query = document_query.filter(or_(*filters))
|
||||
document_query = document_query.where(or_(*filters))
|
||||
documents = document_query.all()
|
||||
# group by dataset_id
|
||||
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||
@ -958,7 +958,7 @@ class DatasetRetrieval:
|
||||
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
# get metadata model config
|
||||
if metadata_model_config is None:
|
||||
|
||||
@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
# get tenant api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
@ -160,7 +160,7 @@ class ToolFileManager:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.filter(
|
||||
.where(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
@ -184,7 +184,7 @@ class ToolFileManager:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.filter(
|
||||
.where(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
@ -204,7 +204,7 @@ class ToolFileManager:
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.filter(
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
@ -228,7 +228,7 @@ class ToolFileManager:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.filter(
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
|
||||
@ -29,7 +29,7 @@ class ToolLabelManager:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
# delete old labels
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
@ -57,7 +57,7 @@ class ToolLabelManager:
|
||||
|
||||
labels = (
|
||||
db.session.query(ToolLabelBinding.label_name)
|
||||
.filter(
|
||||
.where(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
)
|
||||
@ -90,7 +90,7 @@ class ToolLabelManager:
|
||||
provider_ids.append(controller.provider_id)
|
||||
|
||||
labels: list[ToolLabelBinding] = (
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
||||
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
||||
)
|
||||
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
@ -198,7 +198,7 @@ class ToolManager:
|
||||
try:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
@ -216,7 +216,7 @@ class ToolManager:
|
||||
# use the default provider
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
@ -229,7 +229,7 @@ class ToolManager:
|
||||
else:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
@ -316,7 +316,7 @@ class ToolManager:
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
workflow_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -616,7 +616,7 @@ class ToolManager:
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
|
||||
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@classmethod
|
||||
def list_providers_from_api(
|
||||
@ -664,7 +664,7 @@ class ToolManager:
|
||||
# get db api providers
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
@ -687,7 +687,7 @@ class ToolManager:
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
@ -731,7 +731,7 @@ class ToolManager:
|
||||
"""
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
ApiToolProvider.id == provider_id,
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
@ -768,7 +768,7 @@ class ToolManager:
|
||||
"""
|
||||
provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
MCPToolProvider.server_identifier == provider_id,
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
@ -793,7 +793,7 @@ class ToolManager:
|
||||
provider_name = provider
|
||||
provider_obj: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
)
|
||||
@ -885,7 +885,7 @@ class ToolManager:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -902,7 +902,7 @@ class ToolManager:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -919,7 +919,7 @@ class ToolManager:
|
||||
try:
|
||||
mcp_provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
.where(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
.where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
):
|
||||
with flask_app.app_context():
|
||||
dataset = (
|
||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
|
||||
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
|
||||
@ -57,7 +57,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
dataset = (
|
||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
|
||||
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
@ -190,7 +190,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(DatasetDocument) # type: ignore
|
||||
.filter(
|
||||
.where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
|
||||
@ -84,7 +84,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
"""
|
||||
workflow: Workflow | None = (
|
||||
db.session.query(Workflow)
|
||||
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -190,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
db_providers: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
)
|
||||
|
||||
@ -142,12 +142,12 @@ class WorkflowTool(Tool):
|
||||
if not version:
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(Workflow.app_id == app_id, Workflow.version != "draft")
|
||||
.where(Workflow.app_id == app_id, Workflow.version != "draft")
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first()
|
||||
workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found or not published")
|
||||
@ -158,7 +158,7 @@ class WorkflowTool(Tool):
|
||||
"""
|
||||
get the app by app id
|
||||
"""
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
|
||||
@ -228,7 +228,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
# Subquery: Count the number of available documents for each dataset
|
||||
subquery = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
|
||||
.filter(
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
@ -242,8 +242,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
results = (
|
||||
db.session.query(Dataset)
|
||||
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
|
||||
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
|
||||
.filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
|
||||
.where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
|
||||
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
|
||||
.all()
|
||||
)
|
||||
|
||||
@ -370,7 +370,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
.where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
@ -415,7 +415,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
|
||||
document_query = db.session.query(Document).filter(
|
||||
document_query = db.session.query(Document).where(
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
@ -493,9 +493,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
node_data.metadata_filtering_conditions
|
||||
and node_data.metadata_filtering_conditions.logical_operator == "and"
|
||||
): # type: ignore
|
||||
document_query = document_query.filter(and_(*filters))
|
||||
document_query = document_query.where(and_(*filters))
|
||||
else:
|
||||
document_query = document_query.filter(or_(*filters))
|
||||
document_query = document_query.where(or_(*filters))
|
||||
documents = document_query.all()
|
||||
# group by dataset_id
|
||||
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||
@ -507,7 +507,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> list[dict[str, Any]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
if node_data.metadata_model_config is None:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
|
||||
Reference in New Issue
Block a user