Merge branch 'main' into feat/support-agent-sandbox

This commit is contained in:
Novice
2026-03-24 17:12:13 +08:00
136 changed files with 4018 additions and 706 deletions

View File

@ -241,7 +241,7 @@ class AppService:
class ArgsDict(TypedDict):
name: str
description: str
icon_type: str
icon_type: IconType | str | None
icon: str
icon_background: str
use_icon_as_answer_icon: bool
@ -257,7 +257,13 @@ class AppService:
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None
icon_type = args.get("icon_type")
if icon_type is None:
resolved_icon_type = app.icon_type
else:
resolved_icon_type = IconType(icon_type)
app.icon_type = resolved_icon_type
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)

View File

@ -1,8 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any
from typing_extensions import TypedDict
class AuthCredentials(TypedDict):
auth_type: str
config: dict[str, Any]
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
self.credentials = credentials
@abstractmethod

View File

@ -1,9 +1,9 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
from services.auth.auth_type import AuthType
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
def __init__(self, provider: str, credentials: AuthCredentials):
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -3,11 +3,11 @@ from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":

View File

@ -1440,7 +1440,7 @@ class DocumentService:
.filter(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != "qa_model", # Skip qa_model documents
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
)
@ -2040,7 +2040,7 @@ class DocumentService:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_form = IndexStructureType(knowledge_config.doc_form)
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
@ -2640,7 +2640,7 @@ class DocumentService:
document.splitting_completed_at = None
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = document_data.doc_form
document.doc_form = IndexStructureType(document_data.doc_form)
db.session.add(document)
db.session.commit()
# update document segment
@ -3101,7 +3101,7 @@ class DocumentService:
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
if "answer" not in args or not args["answer"]:
raise ValueError("Answer is required")
if not args["answer"].strip():
@ -3158,7 +3158,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
@ -3232,7 +3232,7 @@ class SegmentService:
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
@ -3255,7 +3255,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
@ -3322,7 +3322,7 @@ class SegmentService:
content = args.content or segment.content
if segment.content == content:
segment.word_count = len(content)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@ -3419,7 +3419,7 @@ class SegmentService:
)
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore
else:
@ -3436,7 +3436,7 @@ class SegmentService:
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change

View File

@ -9,6 +9,7 @@ from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
@ -79,9 +80,9 @@ class RagPipelineTransformService:
pipeline = self._create_pipeline(pipeline_yaml)
# save chunk structure to dataset
if doc_form == "hierarchical_model":
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset.chunk_structure = "hierarchical_model"
elif doc_form == "text_model":
elif doc_form == IndexStructureType.PARAGRAPH_INDEX:
dataset.chunk_structure = "text_model"
else:
raise ValueError("Unsupported doc form")
@ -101,7 +102,7 @@ class RagPipelineTransformService:
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
pipeline_yaml = {}
if doc_form == "text_model":
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
@ -132,7 +133,7 @@ class RagPipelineTransformService:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
elif doc_form == "hierarchical_model":
elif doc_form == IndexStructureType.PARENT_CHILD_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
# get graph from transform.file-parentchild.yml