merge error

This commit is contained in:
jyong
2024-09-13 09:49:24 +08:00
parent 9ca0e56a8a
commit 89e81873c4
5 changed files with 132 additions and 76 deletions

View File

@ -23,7 +23,6 @@ class ApiTemplateSetting(BaseModel):
method: str
url: str
request_method: str
authorization: Authorization
api_token: str
headers: Optional[dict] = None
params: Optional[dict] = None
callback_setting: Optional[ProcessStatusSetting] = None

View File

@ -117,6 +117,16 @@ class ExternalDatasetService:
return True
return False
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id,
tenant_id=tenant_id
).first()
if not external_knowledge_binding:
raise ValueError('external knowledge binding not found')
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, api_template_id: str, process_parameter: dict):
api_template = ExternalApiTemplates.query.filter_by(
@ -196,8 +206,6 @@ class ExternalDatasetService:
@staticmethod
def process_external_api(settings: ApiTemplateSetting,
headers: Union[None, dict[str, Any]],
parameter: Union[None, dict[str, Any]],
files: Union[None, dict[str, Any]]) -> httpx.Response:
"""
do http request depending on api bundle
@ -205,14 +213,12 @@ class ExternalDatasetService:
kwargs = {
'url': settings.url,
'headers': headers,
'headers': settings.headers,
'follow_redirects': True,
}
if settings.request_method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
response = getattr(ssrf_proxy, settings.request_method)(data=parameter, files=files, **kwargs)
else:
raise ValueError(f'Invalid http method {settings.request_method}')
response = getattr(ssrf_proxy, settings.request_method)(data=settings.params, files=files, **kwargs)
return response
@staticmethod
@ -246,7 +252,7 @@ class ExternalDatasetService:
return ApiTemplateSetting.parse_obj(settings)
@staticmethod
def create_external_dataset(tenant_id, user_id, args):
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if Dataset.query.filter_by(name=args.get('name'), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
@ -254,6 +260,7 @@ class ExternalDatasetService:
id=args.get('api_template_id'),
tenant_id=tenant_id
).first()
if api_template is None:
raise ValueError('api template not found')
@ -281,4 +288,37 @@ class ExternalDatasetService:
return dataset
@staticmethod
def fetch_external_knowledge_retrival(tenant_id: str,
dataset_id: str,
query: str,
external_retrival_parameters: dict):
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id,
tenant_id=tenant_id
).first()
if not external_knowledge_binding:
raise ValueError('external knowledge binding not found')
external_api_template = ExternalApiTemplates.query.filter_by(
id=external_knowledge_binding.external_api_template_id
).first()
if not external_api_template:
raise ValueError('external api template not found')
settings = json.loads(external_api_template.settings)
headers = {}
if settings.get('api_token'):
headers['Authorization'] = f"Bearer {settings.get('api_token')}"
external_retrival_parameters['query'] = query
api_template_setting = {
'url': f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents",
'request_method': 'post',
'headers': settings.get('headers'),
'params': external_retrival_parameters
}
response = ExternalDatasetService.process_external_api(
ApiTemplateSetting(**api_template_setting), None
)

View File

@ -19,7 +19,8 @@ default_retrieval_model = {
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
def retrieve(cls, dataset: Dataset, query: str, account: Account,
retrieval_model: dict, external_retrieval_model: dict, limit: int = 10) -> dict:
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return {
"query": {
@ -50,6 +51,8 @@ class HitTestingService:
if retrieval_model.get("reranking_mode")
else "reranking_model",
weights=retrieval_model.get("weights", None),
provider=dataset.provider,
external_retrieval_model=external_retrieval_model,
)
end = time.perf_counter()