Compare commits

..

33 Commits

Author SHA1 Message Date
1f5cc071f8 chore(version): bump to 0.9.1 (#8945) 2024-09-30 23:22:21 +08:00
625e4c4c72 fix multiple retrieval in knowledge node (#8942) 2024-09-30 23:07:04 +08:00
7850a28ec8 Revert "chore(version): bump to 0.9.1" (#8944) 2024-09-30 22:53:32 +08:00
730d3a6d7c chore(version): bump to 0.9.1 (#8938) 2024-09-30 22:13:38 +08:00
d6a44e9990 fix: request params for internal dataset (#8940) 2024-09-30 22:10:27 +08:00
3069b5cf57 original dataset update remove unuseful parameters (#8939) 2024-09-30 22:01:32 +08:00
7873e455bb fix: Fix the error when importing web pages using jina (#8937) 2024-09-30 21:27:11 +08:00
a651b73db0 original dataset update issue (#8935) 2024-09-30 21:17:12 +08:00
d2ce4960f1 chore(versioning): bump version to 0.9.0 (#8911) 2024-09-30 18:33:20 +08:00
1af4ca344e Feat: add debounce for search in logs (#8924) 2024-09-30 17:18:47 +08:00
fa837b2dfd fix: fix the issue with the system model configuration update (#8923) 2024-09-30 17:14:13 +08:00
824a71388a chore: translate i18n files (#8917)
Co-authored-by: JohnJyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-09-30 16:35:00 +08:00
4585cffce1 fix: Compatible with special characters in pg full-text search. (#8921)
Co-authored-by: Aurelius Huang <cm.huang@aftership.com>
2024-09-30 16:32:23 +08:00
13046709a9 fix: line in iteration node is not straight (#8918) 2024-09-30 16:04:51 +08:00
9d221a5e19 external knowledge api (#8913)
Co-authored-by: Yi <yxiaoisme@gmail.com>
2024-09-30 15:38:43 +08:00
77aef9ff1d refactor: optimize the calculation of rerank threshold and the logic for forbidden characters in model_uid (#8879) 2024-09-30 12:55:01 +08:00
503561f464 fix: fix the data validation consistency issue in keyword content review (#8908) 2024-09-30 12:52:18 +08:00
ada9d408ac refactor(api/variables): VariableError as a ValueError. (#8554) 2024-09-30 12:48:58 +08:00
3af65b2f45 feat(api): add version comparison logic (#8902) 2024-09-30 11:12:26 +08:00
369e1e6f58 feat(website-crawl): add jina reader as additional alternative for website crawling (#8761) 2024-09-30 09:57:19 +08:00
fb49413a41 feat: add voyage ai as a new model provider (#8747) 2024-09-29 16:55:59 +08:00
42dfde6546 docs: add english versions for the files customizable_model_scale_out and predefined_model_scale_out (#8871) 2024-09-29 16:16:56 +08:00
c531b4a911 fix: #8843 event: tts_message_end always return in api streaming resp… (#8846) 2024-09-29 16:13:20 +08:00
e4ed916baa Add Jamba and Llama3.2 model support (#8878) 2024-09-29 16:12:56 +08:00
4ec977eaba fix(workflow): update tagging logic in GitHub Actions (#8882) 2024-09-29 16:12:42 +08:00
74f58f29f9 chore: bump ruff to 0.6.8 for fixing violation in SIM910 (#8869) 2024-09-29 00:29:59 +08:00
f97607370a refactor: update Callback to an abstract class (#8868) 2024-09-28 21:41:02 +08:00
850492dafa feat: deprecate gte-Qwen2-7B-instruct embedding model (#8866) 2024-09-28 21:40:27 +08:00
61c89a9168 feat: add internlm2.5-20b and qwen2.5-coder-7b model (#8862) 2024-09-28 16:31:02 +08:00
49af18fbd6 fix: customize model credentials were invalid despite the provider credentials being active (#8864) 2024-09-28 15:54:26 +08:00
6cd22f3bca fix: update qwen2.5-coder-7b model name (#8861) 2024-09-28 15:01:27 +08:00
a2e2f8a8c9 fix(workflow/nodes/knowledge-retrieval/use-config): Preserve rerankin… (#8842) 2024-09-28 10:54:50 +08:00
27e33fb15c chore: fix wrong VectorType match case (#8857) 2024-09-28 10:54:04 +08:00
211 changed files with 4870 additions and 1158 deletions

View File

@ -125,7 +125,7 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

View File

@ -5,7 +5,6 @@ from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
from configs.middleware.external.bedrock_config import BedrockConfig
from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
@ -223,6 +222,5 @@ class MiddlewareConfig(
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
BedrockConfig,
):
pass

View File

@ -1,20 +0,0 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class BedrockConfig(BaseSettings):
"""
bedrock configs
"""
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
description="AWS secret access key",
default=None,
)
AWS_ACCESS_KEY_ID: Optional[str] = Field(
description="AWS secret access id",
default=None,
)

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.8.3",
default="0.9.1",
)
COMMIT_SHA: str = Field(

View File

@ -45,7 +45,6 @@ from .datasets import (
external,
file,
hit_testing,
test_external,
website,
)

View File

@ -613,10 +613,10 @@ class DatasetRetrievalSettingApi(Resource):
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@ -627,6 +627,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.PGVECTOR
):
return {
"retrieval_method": [

View File

@ -5,7 +5,6 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
@ -158,48 +157,6 @@ class ExternalApiUseCheckApi(Resource):
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
class ExternalDatasetInitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=True, location="json")
# parser.add_argument('name', nullable=False, required=True,
# help='name is required. Name must be between 1 to 100 characters.',
# type=_validate_name)
# parser.add_argument('description', type=str, required=True, nullable=True, location='json')
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_parameter", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
# validate args
ExternalDatasetService.document_create_args_validate(
current_user.current_tenant_id, args["external_knowledge_api_id"], args["process_parameter"]
)
try:
dataset, documents, batch = ExternalDatasetService.init_external_dataset(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
args=args,
)
except Exception as ex:
raise ProviderNotInitializeError(ex.description)
response = {"dataset": dataset, "documents": documents, "batch": batch}
return response
class ExternalDatasetCreateApi(Resource):
@setup_required
@login_required

View File

@ -1,33 +0,0 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.external_knowledge_service import ExternalDatasetService
class TestExternalApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
parser.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
parser.add_argument(
"knowledge_id",
nullable=False,
required=True,
type=str,
)
args = parser.parse_args()
result = ExternalDatasetService.test_external_knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"]
)
return result, 200
api.add_resource(TestExternalApi, "/retrieval")

View File

@ -14,7 +14,9 @@ class WebsiteCrawlApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json")
parser.add_argument(
"provider", type=str, choices=["firecrawl", "jinareader"], required=True, nullable=True, location="json"
)
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
@ -33,7 +35,7 @@ class WebsiteCrawlStatusApi(Resource):
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args")
parser.add_argument("provider", type=str, choices=["firecrawl", "jinareader"], required=True, location="args")
args = parser.parse_args()
# get crawl status
try:

View File

@ -38,11 +38,52 @@ class VersionApi(Resource):
return result
content = json.loads(response.content)
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]
result["can_auto_update"] = content["canAutoUpdate"]
if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"):
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]
result["can_auto_update"] = content["canAutoUpdate"]
return result
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
def parse_version(version: str) -> tuple:
# Split version into parts and pre-release suffix if any
parts = version.split("-")
version_parts = parts[0].split(".")
pre_release = parts[1] if len(parts) > 1 else None
# Validate version format
if len(version_parts) != 3:
raise ValueError(f"Invalid version format: {version}")
try:
# Convert version parts to integers
major, minor, patch = map(int, version_parts)
return (major, minor, patch, pre_release)
except ValueError:
raise ValueError(f"Invalid version format: {version}")
latest = parse_version(latest_version)
current = parse_version(current_version)
# Compare major, minor, and patch versions
for latest_part, current_part in zip(latest[:3], current[:3]):
if latest_part > current_part:
return True
elif latest_part < current_part:
return False
# If versions are equal, check pre-release suffixes
if latest[3] is None and current[3] is not None:
return True
elif latest[3] is not None and current[3] is None:
return False
elif latest[3] is not None and current[3] is not None:
# Simple string comparison for pre-release versions
return latest[3] > current[3]
return False
api.add_resource(VersionApi, "/version")

View File

@ -72,8 +72,9 @@ class DefaultModelApi(Resource):
provider=model_setting["provider"],
model=model_setting["model"],
)
except Exception:
logging.warning(f"{model_setting['model_type']} save error")
except Exception as ex:
logging.exception(f"{model_setting['model_type']} save error: {ex}")
raise ex
return {"result": "success"}

View File

@ -231,7 +231,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,

View File

@ -212,7 +212,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,

View File

@ -1,2 +1,2 @@
class VariableError(Exception):
class VariableError(ValueError):
pass

View File

@ -248,7 +248,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
if publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None

View File

@ -119,7 +119,7 @@ class ProviderConfiguration(BaseModel):
credentials = model_configuration.credentials
break
if self.custom_configuration.provider:
if not credentials and self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
return credentials

View File

@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -13,7 +14,7 @@ _TEXT_COLOR_MAPPING = {
}
class Callback:
class Callback(ABC):
"""
Base class for callbacks.
Only for LLM.
@ -21,6 +22,7 @@ class Callback:
raise_error: bool = False
@abstractmethod
def on_before_invoke(
self,
llm_instance: AIModel,
@ -48,6 +50,7 @@ class Callback:
"""
raise NotImplementedError()
@abstractmethod
def on_new_chunk(
self,
llm_instance: AIModel,
@ -77,6 +80,7 @@ class Callback:
"""
raise NotImplementedError()
@abstractmethod
def on_after_invoke(
self,
llm_instance: AIModel,
@ -106,6 +110,7 @@ class Callback:
"""
raise NotImplementedError()
@abstractmethod
def on_invoke_error(
self,
llm_instance: AIModel,

View File

@ -0,0 +1,310 @@
## Custom Model Integration
### Introduction
After completing the vendor integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example to demonstrate a complete custom model integration.
It is important to note that for custom models, each model connection requires a complete vendor credential.
Unlike pre-defined models, a custom vendor integration always includes the following two parameters, `model_type` and `model_name`, which do not need to be defined in the vendor YAML file.
![](images/index/image-3.png)
As mentioned earlier, vendors do not need to implement `validate_provider_credential`. The runtime will automatically call the corresponding model layer's `validate_credentials` to validate the credentials based on the model type and name selected by the user.
### Writing the Vendor YAML
First, we need to identify the types of models supported by the vendor we are integrating.
Currently supported model types are as follows:
- `llm` Text Generation Models
- `text_embedding` Text Embedding Models
- `rerank` Rerank Models
- `speech2text` Speech-to-Text
- `tts` Text-to-Speech
- `moderation` Moderation
Xinference supports LLM, Text Embedding, and Rerank, so we will start by writing `xinference.yaml`.
```yaml
provider: xinference # Define the vendor identifier
label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default.
en_US: Xorbits Inference
icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label
en_US: icon_s_en.svg
icon_large: # Large icon
en_US: icon_l_en.svg
help: # Help information
title:
en_US: How to deploy Xinference
zh_Hans: 如何部署 Xinference
url:
en_US: https://github.com/xorbitsai/inference
supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank
- llm
- text-embedding
- rerank
configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models.
- customizable-model
provider_credential_schema:
credential_form_schemas:
```
Then, we need to determine what credentials are required to define a model in Xinference.
- Since it supports three different types of models, we need to specify the `model_type` to denote the model type. Here is how we can define it:
```yaml
provider_credential_schema:
credential_form_schemas:
- variable: model_type
type: select
label:
en_US: Model type
zh_Hans: 模型类型
required: true
options:
- value: text-generation
label:
en_US: Language Model
zh_Hans: 语言模型
- value: embeddings
label:
en_US: Text Embedding
- value: reranking
label:
en_US: Rerank
```
- Next, each model has its own `model_name`, so we need to define that here:
```yaml
- variable: model_name
type: text-input
label:
en_US: Model name
zh_Hans: 模型名称
required: true
placeholder:
zh_Hans: 填写模型名称
en_US: Input model name
```
- Specify the Xinference local deployment address:
```yaml
- variable: server_url
label:
zh_Hans: 服务器URL
en_US: Server url
type: text-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx
```
- Each model has a unique `model_uid`, so we also need to define that here:
```yaml
- variable: model_uid
label:
zh_Hans: 模型UID
en_US: Model uid
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的Model UID
en_US: Enter the model uid
```
Now, we have completed the basic definition of the vendor.
### Writing the Model Code
Next, let's take the `llm` type as an example and write `xinference/llm/llm.py`.
In `llm.py`, create a Xinference LLM class, for example `XinferenceAILargeLanguageModel` (the name is arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM Invocation
Implement the core method for LLM invocation, supporting both stream and synchronous responses.
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool usage
:param stop: stop words
:param stream: is the response a stream
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
When implementing, be sure to use two separate functions to return data for the synchronous and streaming responses. Python treats any function containing the `yield` keyword as a generator function and fixes its return type to `Generator`, so the two paths must be implemented separately. Here's an example (note that the example uses simplified parameters; in a real implementation, use the parameter list defined above):
```python
def _invoke(self, stream: bool, **kwargs) \
-> Union[LLMResult, Generator]:
if stream:
return self._handle_stream_response(**kwargs)
return self._handle_sync_response(**kwargs)
def _handle_stream_response(self, **kwargs) -> Generator:
for chunk in response:
yield chunk
def _handle_sync_response(self, **kwargs) -> LLMResult:
return LLMResult(**response)
```
- Pre-compute Input Tokens
If the model does not provide an interface for pre-computing tokens, you can return 0 directly.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool usage
:return: token count
"""
```
Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get an approximate token count. This method is provided by the `AIModel` base class and uses a GPT-2 tokenizer for the calculation. Note that it is only a substitute, so the count may not exactly match the model's own tokenizer.
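A minimal sketch of that fallback, assuming the `XinferenceAILargeLanguageModel` class above (the way messages are flattened into a single string is illustrative, and the types come from the imports used in the signatures above):
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    # Xinference exposes no token-counting endpoint, so approximate with the
    # GPT-2 tokenizer helper inherited from the AIModel base class.
    text = " ".join(str(message.content) for message in prompt_messages)
    return self._get_num_tokens_by_gpt2(text)
```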
- Model Credentials Validation
Similar to vendor credentials validation, this method validates individual model credentials.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return: None
"""
```
- Model Parameter Schema
Unlike predefined models, no YAML file declares which parameters a model supports, so we need to generate the model parameter schema dynamically.
For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` parameters.
However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature', type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度', en_US='Temperature'
)
),
ParameterRule(
name='top_p', type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P', en_US='Top P'
)
),
ParameterRule(
name='max_tokens', type=ParameterType.INT,
use_template='max_tokens',
min=1,
default=512,
label=I18nObject(
zh_Hans='最大生成长度', en_US='Max Tokens'
)
)
]
# if model is A, add top_k to rules
if model == 'A':
rules.append(
ParameterRule(
name='top_k', type=ParameterType.INT,
use_template='top_k',
min=1,
default=50,
label=I18nObject(
zh_Hans='Top K', en_US='Top K'
)
)
)
"""
some NOT IMPORTANT code here
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=model_type,
model_properties={
ModelPropertyKey.MODE: ModelType.LLM,
},
parameter_rules=rules
)
return entity
```
- Exception Error Mapping
When a model invocation error occurs, it should be mapped to the runtime's specified `InvokeError` type, enabling Dify to handle different errors appropriately.
Runtime Errors:
- `InvokeConnectionError` Connection error during invocation
- `InvokeServerUnavailableError` Service provider unavailable
- `InvokeRateLimitError` Rate limit reached
- `InvokeAuthorizationError` Authorization failure
- `InvokeBadRequestError` Invalid request parameters
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
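For instance, a hedged sketch of this mapping for an HTTP-based integration, assuming the invocation code talks to the server with `httpx` and that the `InvokeError` subclasses are imported from the runtime; the exact exception lists depend on how your `_invoke` implementation issues requests:
```python
import httpx

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    # Group transport-level exceptions under the unified runtime error types.
    return {
        InvokeConnectionError: [httpx.ConnectError, httpx.ConnectTimeout],
        InvokeServerUnavailableError: [httpx.RemoteProtocolError],
        InvokeRateLimitError: [],
        InvokeAuthorizationError: [httpx.HTTPStatusError],
        InvokeBadRequestError: [httpx.RequestError, ValueError],
    }
```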
For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).

Binary files not shown (four new image files added: 230 KiB, 205 KiB, 44 KiB, 262 KiB).

View File

@ -0,0 +1,173 @@
## Predefined Model Integration
After completing the vendor integration, the next step is to integrate the models from the vendor.
First, we need to determine the type of model to be integrated and create the corresponding model type `module` under the respective vendor's directory.
Currently supported model types are:
- `llm` Text Generation Model
- `text_embedding` Text Embedding Model
- `rerank` Rerank Model
- `speech2text` Speech-to-Text
- `tts` Text-to-Speech
- `moderation` Moderation
Continuing with `Anthropic` as an example: since Anthropic only supports LLM, create a `module` named `llm` under `model_providers.anthropic`.
For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`.
### Prepare Model YAML
```yaml
model: claude-2.1 # Model identifier
# Display name of the model, which can be set to en_US English or zh_Hans Chinese. If zh_Hans is not set, it will default to en_US.
# This can also be omitted, in which case the model identifier will be used as the label
label:
en_US: claude-2.1
model_type: llm # Model type, claude-2.1 is an LLM
features: # Supported features, agent-thought supports Agent reasoning, vision supports image understanding
- agent-thought
model_properties: # Model properties
mode: chat # LLM mode, complete for text completion models, chat for conversation models
context_size: 200000 # Maximum context size
parameter_rules: # Parameter rules for the model call; only LLM requires this
- name: temperature # Parameter variable name
# Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
# The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
# Additional configuration parameters will override the default configuration if set
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label: # Display name of the parameter
zh_Hans: 取样数量
en_US: Top k
type: int # Parameter type, supports float/int/string/boolean
help: # Help information, describing the parameter's function
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false # Whether the parameter is mandatory; can be omitted
- name: max_tokens_to_sample
use_template: max_tokens
default: 4096 # Default value of the parameter
min: 1 # Minimum value of the parameter, applicable to float/int only
max: 4096 # Maximum value of the parameter, applicable to float/int only
pricing: # Pricing information
input: '8.00' # Input unit price, i.e., prompt price
output: '24.00' # Output unit price, i.e., response content price
unit: '0.000001' # Price unit; with this unit, the prices above are per 1M tokens
currency: USD # Price currency
```
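As a quick sanity check on the pricing fields: with `unit: '0.000001'` and `input: '8.00'`, a 1,000-token prompt costs 1000 × 0.000001 × 8.00 = $0.008, i.e. the listed prices are effectively per 1,000,000 tokens.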
It is recommended to prepare all model configurations before starting the implementation of the model code.
You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity).
### Implement the Model Call Code
Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (or any other name), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
- LLM Call
Implement the core method for calling the LLM, supporting both streaming and synchronous responses.
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
```
Be sure to use two separate functions for returning data, one for synchronous returns and one for streaming returns, because Python identifies any function containing the `yield` keyword as a generator function and fixes its return type to `Generator`. Synchronous and streaming returns therefore need to be implemented separately, as shown below (note that the example uses simplified parameters; an actual implementation should follow the parameter list above):
```python
def _invoke(self, stream: bool, **kwargs) \
-> Union[LLMResult, Generator]:
if stream:
return self._handle_stream_response(**kwargs)
return self._handle_sync_response(**kwargs)
def _handle_stream_response(self, **kwargs) -> Generator:
for chunk in response:
yield chunk
def _handle_sync_response(self, **kwargs) -> LLMResult:
return LLMResult(**response)
```
- Pre-compute Input Tokens
If the model does not provide an interface to precompute tokens, return 0 directly.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
```
- Validate Model Credentials
Similar to vendor credential validation, but specific to a single model.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
```
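A common pattern, shown here as a sketch rather than the required implementation, is to attempt a minimal invocation and surface any failure as a credentials error (this assumes `UserPromptMessage` and `CredentialsValidateFailedError` are imported from the model runtime's entities and errors modules):
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    try:
        # A one-message call is enough to prove the API key works;
        # max_tokens_to_sample matches the parameter defined in the YAML above.
        self._invoke(model=model, credentials=credentials,
                     prompt_messages=[UserPromptMessage(content="ping")],
                     model_parameters={"max_tokens_to_sample": 5}, stream=False)
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))
```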
- Map Invoke Errors
When a model call fails, map it to a specific `InvokeError` type as required by Runtime, allowing Dify to handle different errors accordingly.
Runtime Errors:
- `InvokeConnectionError` Connection error
- `InvokeServerUnavailableError` Service provider unavailable
- `InvokeRateLimitError` Rate limit reached
- `InvokeAuthorizationError` Authorization failed
- `InvokeBadRequestError` Parameter error
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
```
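As a concrete illustration, here is a sketch built on the exception classes exposed by the official `anthropic` Python SDK; treat the exact lists as an assumption to adapt rather than the mapping Dify ships (the `InvokeError` subclasses are assumed to be imported from the runtime):
```python
import anthropic

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    # Each runtime error lists the SDK exceptions that are folded into it.
    return {
        InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError],
        InvokeServerUnavailableError: [anthropic.InternalServerError],
        InvokeRateLimitError: [anthropic.RateLimitError],
        InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError],
        InvokeBadRequestError: [anthropic.BadRequestError, anthropic.NotFoundError],
    }
```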
For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).

View File

@ -58,7 +58,7 @@ provider_credential_schema: # Provider credential rules, as Anthropic only supp
en_US: Enter your API URL
```
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#Provider).
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#provider).
### Implementing Provider Code

View File

@ -117,7 +117,7 @@ model_credential_schema:
en_US: Enter your API Base
```
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#Provider)。
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。
#### 实现供应商代码

View File

@ -40,3 +40,4 @@
- fireworks
- mixedbread
- nomic
- voyage

View File

@ -6,6 +6,8 @@
- anthropic.claude-v2:1
- anthropic.claude-3-sonnet-v1:0
- anthropic.claude-3-haiku-v1:0
- ai21.jamba-1-5-large-v1:0
- ai21.jamba-1-5-mini-v1:0
- cohere.command-light-text-v14
- cohere.command-text-v14
- cohere.command-r-plus-v1.0
@ -15,6 +17,10 @@
- meta.llama3-1-405b-instruct-v1:0
- meta.llama3-8b-instruct-v1:0
- meta.llama3-70b-instruct-v1:0
- us.meta.llama3-2-1b-instruct-v1:0
- us.meta.llama3-2-3b-instruct-v1:0
- us.meta.llama3-2-11b-instruct-v1:0
- us.meta.llama3-2-90b-instruct-v1:0
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- mistral.mistral-large-2407-v1:0

View File

@ -0,0 +1,26 @@
model: ai21.jamba-1-5-large-v1:0
label:
en_US: Jamba 1.5 Large
model_type: llm
model_properties:
mode: completion
context_size: 256000
parameter_rules:
- name: temperature
use_template: temperature
default: 1
min: 0.0
max: 2.0
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
pricing:
input: '0.002'
output: '0.008'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,26 @@
model: ai21.jamba-1-5-mini-v1:0
label:
en_US: Jamba 1.5 Mini
model_type: llm
model_properties:
mode: completion
context_size: 256000
parameter_rules:
- name: temperature
use_template: temperature
default: 1
min: 0.0
max: 2.0
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
pricing:
input: '0.0002'
output: '0.0004'
unit: '0.001'
currency: USD

View File

@ -63,6 +63,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
{"prefix": "us.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "eu.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "us.meta.llama3-2", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False},
{"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False},
@ -70,6 +71,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
{"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
{"prefix": "ai21.jamba-1-5", "support_system_prompts": True, "support_tool_use": False},
]
@staticmethod

View File

@ -0,0 +1,29 @@
model: us.meta.llama3-2-11b-instruct-v1:0
label:
en_US: US Meta Llama 3.2 11B Instruct
model_type: llm
features:
- vision
- tool-call
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
min: 0.0
max: 1
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.00035'
output: '0.00035'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,26 @@
model: us.meta.llama3-2-1b-instruct-v1:0
label:
en_US: US Meta Llama 3.2 1B Instruct
model_type: llm
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
min: 0.0
max: 1
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.0001'
output: '0.0001'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,26 @@
model: us.meta.llama3-2-3b-instruct-v1:0
label:
en_US: US Meta Llama 3.2 3B Instruct
model_type: llm
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
min: 0.0
max: 1
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.00015'
output: '0.00015'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,31 @@
model: us.meta.llama3-2-90b-instruct-v1:0
label:
en_US: US Meta Llama 3.2 90B Instruct
model_type: llm
features:
- tool-call
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
min: 0.0
max: 1
- name: top_p
use_template: top_p
default: 0.9
min: 0
max: 1
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.002'
output: '0.002'
unit: '0.001'
currency: USD

View File

@ -1,24 +1,23 @@
- Qwen2.5-72B-Instruct
- Qwen2.5-7B-Instruct
- Qwen2-72B-Instruct
- Qwen2-72B-Instruct-AWQ-int4
- Qwen2-72B-Instruct-GPTQ-Int4
- Qwen2-7B-Instruct
- Qwen2-7B
- Qwen1.5-110B-Chat-GPTQ-Int4
- Qwen1.5-72B-Chat-GPTQ-Int4
- Qwen1.5-7B
- Qwen-14B-Chat-Int4
- Yi-Coder-1.5B-Chat
- Yi-Coder-9B-Chat
- Qwen2-72B-Instruct-AWQ-int4
- Yi-1_5-9B-Chat-16K
- Qwen2-7B-Instruct
- Reflection-Llama-3.1-70B
- Qwen2-72B-Instruct
- Meta-Llama-3.1-8B-Instruct
- Meta-Llama-3.1-405B-Instruct-AWQ-INT4
- Meta-Llama-3-70B-Instruct-GPTQ-Int4
- chatglm3-6b
- Meta-Llama-3-8B-Instruct
- Llama3-Chinese_v2
- deepseek-v2-lite-chat
- Qwen2-72B-Instruct-GPTQ-Int4
- Qwen2-7B
- Qwen-14B-Chat-Int4
- Qwen1.5-72B-Chat-GPTQ-Int4
- Qwen1.5-7B
- Qwen1.5-110B-Chat-GPTQ-Int4
- deepseek-v2-chat
- chatglm3-6b

View File

@ -0,0 +1,4 @@
- gte-Qwen2-7B-instruct
- BAAI/bge-large-en-v1.5
- BAAI/bge-large-zh-v1.5
- BAAI/bge-m3

View File

@ -2,3 +2,4 @@ model: gte-Qwen2-7B-instruct
model_type: text-embedding
model_properties:
context_size: 2048
deprecated: true

View File

@ -1,18 +1,17 @@
- Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2.5-Math-72B-Instruct
- Qwen/Qwen2.5-32B-Instruct
- Qwen/Qwen2.5-14B-Instruct
- Qwen/Qwen2.5-7B-Instruct
- Qwen/Qwen2.5-Coder-7B-Instruct
- deepseek-ai/DeepSeek-V2.5
- Qwen/Qwen2.5-Math-72B-Instruct
- Qwen/Qwen2-72B-Instruct
- Qwen/Qwen2-57B-A14B-Instruct
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-1.5B-Instruct
- deepseek-ai/DeepSeek-V2.5
- deepseek-ai/DeepSeek-V2-Chat
- deepseek-ai/DeepSeek-Coder-V2-Instruct
- THUDM/glm-4-9b-chat
- THUDM/chatglm3-6b
- 01-ai/Yi-1.5-34B-Chat-16K
- 01-ai/Yi-1.5-9B-Chat-16K
- 01-ai/Yi-1.5-6B-Chat
@ -26,13 +25,4 @@
- google/gemma-2-27b-it
- google/gemma-2-9b-it
- mistralai/Mistral-7B-Instruct-v0.2
- Pro/Qwen/Qwen2-7B-Instruct
- Pro/Qwen/Qwen2-1.5B-Instruct
- Pro/THUDM/glm-4-9b-chat
- Pro/THUDM/chatglm3-6b
- Pro/01-ai/Yi-1.5-9B-Chat-16K
- Pro/01-ai/Yi-1.5-6B-Chat
- Pro/internlm/internlm2_5-7b-chat
- Pro/meta-llama/Meta-Llama-3.1-8B-Instruct
- Pro/meta-llama/Meta-Llama-3-8B-Instruct
- Pro/google/gemma-2-9b-it
- mistralai/Mixtral-8x7B-Instruct-v0.1

View File

@ -0,0 +1,30 @@
model: internlm/internlm2_5-20b-chat
label:
en_US: internlm/internlm2_5-20b-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: frequency_penalty
use_template: frequency_penalty
pricing:
input: '1'
output: '1'
unit: '0.000001'
currency: RMB

View File

@ -0,0 +1,74 @@
model: Qwen/Qwen2.5-Coder-7B-Instruct
label:
en_US: Qwen/Qwen2.5-Coder-7B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 8192
min: 1
max: 8192
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: seed
required: false
type: int
default: 1234
label:
zh_Hans: 随机种子
en_US: Random seed
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
zh_Hans: 重复惩罚
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
- name: response_format
use_template: response_format
pricing:
input: '0'
output: '0'
unit: '0.000001'
currency: RMB

View File

@ -0,0 +1,74 @@
model: Qwen/Qwen2.5-Math-72B-Instruct
label:
en_US: Qwen/Qwen2.5-Math-72B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 2000
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: seed
required: false
type: int
default: 1234
label:
zh_Hans: 随机种子
en_US: Random seed
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
zh_Hans: 重复惩罚
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
- name: response_format
use_template: response_format
pricing:
input: '4.13'
output: '4.13'
unit: '0.000001'
currency: RMB

View File

@ -1,7 +1,7 @@
# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models
model: qwen2.5-7b-instruct
model: qwen2.5-coder-7b-instruct
label:
en_US: qwen2.5-7b-instruct
en_US: qwen2.5-coder-7b-instruct
model_type: llm
features:
- agent-thought

View File

@ -0,0 +1,21 @@
<svg version="1.0" xmlns="http://www.w3.org/2000/svg" width="100.000000pt" height="19.000000pt" viewBox="0 0 300.000000 57.000000" preserveAspectRatio="xMidYMid meet"><g transform="translate(0.000000,57.000000) scale(0.100000,-0.100000)" fill="#000000" stroke="none"><path d="M2505 368 c-38 -84 -86 -188 -106 -230 l-38 -78 27 0 c24 0 30 7 55
75 l28 75 100 0 100 0 25 -55 c13 -31 24 -64 24 -75 0 -17 7 -20 44 -20 l43 0
-37 73 c-20 39 -68 143 -106 229 -38 87 -74 158 -80 158 -5 0 -41 -69 -79
-152z m110 -30 c22 -51 41 -95 42 -98 2 -3 -36 -6 -83 -7 -76 -1 -85 0 -81 15
12 40 72 182 77 182 3 0 24 -41 45 -92z"/><path d="M63 493 c19 -61 197 -438 209 -440 10 -2 147 282 216 449 2 4 -10 8
-27 8 -23 0 -31 -5 -31 -17 0 -16 -142 -365 -146 -360 -8 11 -144 329 -149
350 -6 23 -12 27 -42 27 -29 0 -34 -3 -30 -17z"/><path d="M2855 285 l0 -225 30 0 30 0 0 225 0 225 -30 0 -30 0 0 -225z"/><path d="M588 380 c-55 -30 -82 -74 -86 -145 -3 -50 0 -66 20 -95 39 -58 82
-80 153 -80 68 0 110 21 149 73 32 43 30 150 -3 196 -47 66 -158 90 -233 51z
m133 -16 c59 -30 89 -156 54 -224 -45 -87 -162 -78 -201 16 -18 44 -18 128 1
164 28 55 90 73 146 44z"/><path d="M935 303 l76 -98 -7 -72 -6 -73 33 0 34 0 -3 78 -4 77 71 93 c65 85
68 92 46 92 -15 0 -29 -9 -36 -22 -18 -33 -90 -128 -98 -128 -6 1 -67 85 -88
122 -8 15 -24 23 -53 25 l-41 4 76 -98z"/><path d="M1257 230 c-82 -169 -83 -170 -57 -170 17 0 27 6 27 15 0 8 7 31 17
52 l17 38 79 0 78 1 16 -34 c9 -18 16 -42 16 -52 0 -17 7 -20 41 -20 22 0 39
3 37 8 -2 4 -39 80 -83 170 -43 89 -84 162 -92 162 -7 0 -50 -76 -96 -170z
m90 -38 c-33 -2 -61 -1 -63 1 -2 2 10 34 26 71 l31 68 33 -68 33 -69 -60 -3z"/><path d="M1665 386 c-37 -16 -84 -63 -97 -96 -13 -35 -12 -104 2 -132 49 -94
182 -134 280 -83 24 12 29 22 32 64 3 49 3 49 -30 53 l-33 4 3 -45 c4 -61 -5
-71 -60 -71 -93 0 -142 57 -142 164 0 44 5 60 25 85 47 55 136 65 184 20 30
-28 35 -20 11 19 -19 31 -22 32 -82 32 -35 -1 -76 -7 -93 -14z"/><path d="M1955 230 l0 -170 91 0 c76 0 93 3 98 16 4 9 5 18 4 20 -2 1 -31 -1
-66 -5 -34 -4 -64 -5 -67 -3 -3 3 -5 36 -5 73 l0 68 55 -6 c49 -5 55 -4 55 13
0 17 -6 19 -55 16 l-55 -4 0 61 0 61 64 0 c48 0 65 4 70 15 4 13 -10 15 -92
15 l-97 0 0 -170z"/></g></svg>


View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="64px" height="64px" viewBox="0 0 64 64" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>voyage</title>
<g id="voyage" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<rect id="矩形" fill="#333333" x="0" y="0" width="64" height="64" rx="12"></rect>
<path d="M12.1128004,51.4376727 C13.8950799,45.8316747 30.5922254,11.1847688 31.7178757,11.0009656 C32.6559176,10.8171624 45.5070913,36.9172188 51.9795803,52.2647871 C52.1671887,52.6323936 51.0415384,53 49.4468672,53 C47.2893709,53 46.5389374,52.540492 46.5389374,51.4376727 C46.5389374,49.967247 33.2187427,17.8935861 32.8435259,18.3530942 C32.0930924,19.3640118 19.3357228,48.5887229 18.8667019,50.5186566 C18.3038768,52.6323936 17.7410516,53 14.926926,53 C12.2066045,53 11.7375836,52.7242952 12.1128004,51.4376727 Z" id="路径" fill="#FFFFFF" transform="translate(32, 32) scale(1, -1) translate(-32, -32)"></path>
</g>
</svg>


View File

@ -0,0 +1,4 @@
model: rerank-1
model_type: rerank
model_properties:
context_size: 8000

View File

@ -0,0 +1,4 @@
model: rerank-lite-1
model_type: rerank
model_properties:
context_size: 4000

View File

@ -0,0 +1,123 @@
from typing import Optional
import httpx
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class VoyageRerankModel(RerankModel):
"""
Model class for Voyage rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
base_url = credentials.get("base_url", "https://api.voyageai.com/v1")
base_url = base_url.removesuffix("/")
try:
response = httpx.post(
base_url + "/rerank",
json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True},
headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"},
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results["data"]:
rerank_document = RerankDocument(
index=result["index"],
text=result["document"],
score=result["relevance_score"],
)
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))},
)
return entity

View File

@ -0,0 +1,172 @@
import time
from json import JSONDecodeError, dumps
from typing import Optional
import requests
from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
class VoyageTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Voyage text embedding model.
"""
api_base: str = "https://api.voyageai.com/v1"
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
api_key = credentials["api_key"]
if not api_key:
raise CredentialsValidateFailedError("api_key is required")
base_url = credentials.get("base_url", self.api_base)
base_url = base_url.removesuffix("/")
url = base_url + "/embeddings"
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
voyage_input_type = "null"
if input_type is not None:
voyage_input_type = input_type.value
data = {"model": model, "input": texts, "input_type": voyage_input_type}
try:
response = requests.post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:
resp = response.json()
msg = resp["detail"]
if response.status_code == 401:
raise InvokeAuthorizationError(msg)
elif response.status_code == 429:
raise InvokeRateLimitError(msg)
elif response.status_code == 500:
raise InvokeServerUnavailableError(msg)
else:
raise InvokeBadRequestError(msg)
except JSONDecodeError as e:
raise InvokeServerUnavailableError(
f"Failed to convert response to json: {e} with text: {response.text}"
)
try:
resp = response.json()
embeddings = resp["data"]
usage = resp["usage"]
except Exception as e:
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
result = TextEmbeddingResult(
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
)
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(model=model, credentials=credentials, texts=["ping"])
except Exception as e:
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at,
)
return usage
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)
return entity
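
The `_invoke` method above builds a plain POST to the Voyage `/embeddings` endpoint. A standalone sketch of the same request, useful for checking credentials outside Dify (VOYAGE_API_KEY is a placeholder):

import requests

resp = requests.post(
    "https://api.voyageai.com/v1/embeddings",
    headers={"Authorization": "Bearer VOYAGE_API_KEY", "Content-Type": "application/json"},
    json={"model": "voyage-3", "input": ["hello"], "input_type": "document"},
)
resp.raise_for_status()
print(len(resp.json()["data"][0]["embedding"]))  # embedding dimension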

View File

@ -0,0 +1,8 @@
model: voyage-3-lite
model_type: text-embedding
model_properties:
context_size: 32000
pricing:
input: '0.00002'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,8 @@
model: voyage-3
model_type: text-embedding
model_properties:
context_size: 32000
pricing:
input: '0.00006'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,28 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class VoyageProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
# Use the `voyage-3` model for validation,
# regardless of which model is passed in (text completion or chat model)
model_instance.validate_credentials(model="voyage-3", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -0,0 +1,31 @@
provider: voyage
label:
en_US: Voyage
description:
en_US: Embedding and Rerank Model Supported
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#EFFDFD"
help:
title:
en_US: Get your API key from Voyage AI
zh_Hans: 从 Voyage 获取 API Key
url:
en_US: https://dash.voyageai.com/
supported_model_types:
- text-embedding
- rerank
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key

View File

@ -59,6 +59,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.model_providers.xinference.xinference_helper import (
XinferenceHelper,
XinferenceModelExtraParameter,
validate_model_uid,
)
from core.model_runtime.utils import helper
@ -114,7 +115,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
}
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
extra_param = XinferenceHelper.get_xinference_extra_parameter(

View File

@ -15,6 +15,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
class XinferenceRerankModel(RerankModel):
@ -77,10 +78,7 @@ class XinferenceRerankModel(RerankModel):
)
# score threshold check
if score_threshold is not None:
if result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
else:
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
@ -94,7 +92,7 @@ class XinferenceRerankModel(RerankModel):
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/")

View File

@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
class XinferenceSpeech2TextModel(Speech2TextModel):
@ -42,7 +43,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/")

View File

@ -17,7 +17,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
class XinferenceTextEmbeddingModel(TextEmbeddingModel):
@ -110,7 +110,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
server_url = credentials["server_url"]

View File

@ -15,7 +15,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
class XinferenceText2SpeechModel(TTSModel):
@ -70,7 +70,7 @@ class XinferenceText2SpeechModel(TTSModel):
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/")

View File

@ -132,3 +132,16 @@ class XinferenceHelper:
context_length=context_length,
model_family=model_family,
)
def validate_model_uid(credentials: dict) -> bool:
"""
Validate the model_uid within the credentials dictionary to ensure it does not
contain forbidden characters ("/", "?", "#").
:param credentials: model credentials
:return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False.
"""
forbidden_characters = ["/", "?", "#"]
model_uid = credentials.get("model_uid", "")
return not any(char in forbidden_characters for char in model_uid)
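
A quick illustration of the helper's behavior (the uid values are hypothetical):

assert validate_model_uid({"model_uid": "my-model"}) is True
assert validate_model_uid({"model_uid": "my/model"}) is False  # contains "/"
assert validate_model_uid({}) is True  # missing uid defaults to ""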

View File

@ -48,7 +48,7 @@ from ._utils import (
)
if TYPE_CHECKING:
from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema
from pydantic_core.core_schema import ModelField
__all__ = ["BaseModel", "GenericModel"]
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")

View File

@ -248,7 +248,7 @@ def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
@functools.wraps(func)
def wrapper(*args: object, **kwargs: object) -> object:
given_params: set[str] = set()
for i, _ in enumerate(args):
for i in range(len(args)):
try:
given_params.add(positional[i])
except IndexError:

View File

@ -18,8 +18,12 @@ class KeywordsModeration(Moderation):
if not config.get("keywords"):
raise ValueError("keywords is required")
if len(config.get("keywords")) > 1000:
raise ValueError("keywords length must be less than 1000")
if len(config.get("keywords")) > 10000:
raise ValueError("keywords length must be less than 10000")
keywords_row_len = config["keywords"].split("\n")
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
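
The change above raises the total-length cap on the keywords string from 1000 to 10000 characters and adds a cap of 100 newline-separated rows. A quick illustration with hypothetical configs:

ok = {"keywords": "spam\nscam\nphishing"}  # 3 rows, well under both limits
too_many_rows = {"keywords": "\n".join(f"kw{i}" for i in range(150))}  # 150 rows: rejected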

View File

@ -45,7 +45,7 @@ class Jieba(BaseKeyword):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list", None)
keywords_list = kwargs.get("keywords_list")
for i in range(len(texts)):
text = texts[i]
if keywords_list:

View File

@ -166,7 +166,7 @@ class PGVector(BaseVector):
with self._get_cursor() as cur:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), to_tsquery(%s)) AS score
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
ORDER BY score DESC
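
The fix above swaps `to_tsquery` for `plainto_tsquery` in the ts_rank expression, matching the `plainto_tsquery` already used in the WHERE clause. Because `plainto_tsquery` parses its argument as plain text, user queries containing tsquery syntax characters (`&`, `|`, `!`, `:`, parentheses) no longer raise a syntax error. A standalone sketch, assuming a reachable PostgreSQL instance (the DSN and table name are hypothetical):

import psycopg2

conn = psycopg2.connect("dbname=dify")  # hypothetical DSN
query = "c++ & rust!"  # raw user input; invalid as to_tsquery syntax
with conn.cursor() as cur:
    cur.execute(
        "SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score "
        "FROM embeddings WHERE to_tsvector(text) @@ plainto_tsquery(%s) ORDER BY score DESC",
        (query, query),
    )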

View File

@ -56,7 +56,7 @@ class TencentVector(BaseVector):
return self._client.create_database(database_name=self._client_config.database)
def get_type(self) -> str:
return "tencent"
return VectorType.TENCENT
def to_index_struct(self) -> dict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

View File

@ -12,6 +12,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
@ -171,6 +172,15 @@ class ExtractProcessor:
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
elif extract_setting.website_info.provider == "jinareader":
extractor = JinaReaderWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
tenant_id=extract_setting.website_info.tenant_id,
mode=extract_setting.website_info.mode,
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
else:
raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}")
else:

View File

@ -0,0 +1,35 @@
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService
class JinaReaderWebExtractor(BaseExtractor):
"""
Crawl and scrape websites and return content in clean, LLM-ready markdown.
"""
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id
self.tenant_id = tenant_id
self.mode = mode
self.only_main_content = only_main_content
def extract(self) -> list[Document]:
"""Extract content from the URL."""
documents = []
if self.mode == "crawl":
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id)
if crawl_data is None:
return []
document = Document(
page_content=crawl_data.get("content", ""),
metadata={
"source_url": crawl_data.get("url"),
"description": crawl_data.get("description"),
"title": crawl_data.get("title"),
},
)
documents.append(document)
return documents
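
A usage sketch (all values hypothetical; in Dify the extractor is constructed by ExtractProcessor from the dataset's website extract settings):

extractor = JinaReaderWebExtractor(
    url="https://example.com",
    job_id="task-123",
    tenant_id="tenant-abc",
    mode="crawl",
    only_main_content=True,
)
documents = extractor.extract()  # list[Document] with markdown page_content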

View File

@ -229,7 +229,9 @@ class KnowledgeRetrievalNode(BaseNode):
source["content"] = segment.get_sign_content()
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score"), reverse=True)
retrieval_resource_list = sorted(
retrieval_resource_list, key=lambda x: x.get("metadata").get("score"), reverse=True
)
position = 1
for item in retrieval_resource_list:
item["metadata"]["position"] = position

View File

@ -14,7 +14,7 @@ from models.dataset import Document
@document_index_created.connect
def handle(sender, **kwargs):
dataset_id = sender
document_ids = kwargs.get("document_ids", None)
document_ids = kwargs.get("document_ids")
documents = []
start_at = time.perf_counter()
for document_id in document_ids:

api/poetry.lock (generated, 1549 lines changed)

File diff suppressed because it is too large

View File

@ -123,6 +123,7 @@ FIRECRAWL_API_KEY = "fc-"
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
MIXEDBREAD_API_KEY = "mk-aaaaaaaaaaaaaaaaaaaa"
VOYAGE_API_KEY = "va-aaaaaaaaaaaaaaaaaaaa"
[tool.poetry]
name = "dify-api"
@ -287,4 +288,4 @@ optional = true
[tool.poetry.group.lint.dependencies]
dotenv-linter = "~0.5.0"
ruff = "~0.6.5"
ruff = "~0.6.8"

View File

@ -1,10 +1,13 @@
from services.auth.firecrawl import FirecrawlAuth
from services.auth.jina import JinaAuth
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
if provider == "firecrawl":
self.auth = FirecrawlAuth(credentials)
elif provider == "jinareader":
self.auth = JinaAuth(credentials)
else:
raise ValueError("Invalid provider")

api/services/auth/jina.py (new file, 44 lines)
View File

@ -0,0 +1,44 @@
import json
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
self.api_key = credentials.get("config").get("api_key", None)
if not self.api_key:
raise ValueError("No API key provided")
def validate_credentials(self):
headers = self._prepare_headers()
options = {
"url": "https://example.com",
}
response = self._post_request("https://r.jina.ai", options, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@ -250,6 +250,9 @@ class DatasetService:
db.session.commit()
else:
data.pop("partial_member_list", None)
data.pop("external_knowledge_api_id", None)
data.pop("external_knowledge_id", None)
data.pop("external_retrieval_model", None)
filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
action = None
if dataset.indexing_technique != data["indexing_technique"]:

View File

@ -1,25 +1,19 @@
import json
import random
import time
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any, Optional, Union
import boto3
import httpx
import validators
# from tasks.external_document_indexing_task import external_document_indexing_task
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from models.dataset import (
Dataset,
Document,
ExternalKnowledgeApis,
ExternalKnowledgeBindings,
)
from models.model import UploadFile
from services.entities.external_knowledge_entities.external_knowledge_entities import (
Authorization,
ExternalKnowledgeApiSetting,
@ -150,70 +144,6 @@ class ExternalDatasetService:
if parameter.get("required", False) and not process_parameter.get(parameter.get("name")):
raise ValueError(f'{parameter.get("name")} is required')
@staticmethod
def init_external_dataset(tenant_id: str, user_id: str, args: dict, created_from: str = "web"):
external_knowledge_api_id = args.get("external_knowledge_api_id")
data_source = args.get("data_source")
if data_source is None:
raise ValueError("data source is required")
process_parameter = args.get("process_parameter")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
dataset = Dataset(
tenant_id=tenant_id,
name=args.get("name"),
description=args.get("description", ""),
provider="external",
created_by=user_id,
)
db.session.add(dataset)
db.session.flush()
document = Document.query.filter_by(dataset_id=dataset.id).order_by(Document.position.desc()).first()
position = document.position + 1 if document else 1
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
document_ids = []
if data_source["type"] == "upload_file":
upload_file_list = data_source["info_list"]["file_info_list"]["file_ids"]
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
if file:
data_source_info = {
"upload_file_id": file_id,
}
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
data_source_type=data_source["type"],
data_source_info=json.dumps(data_source_info),
batch=batch,
name=file.name,
created_from=created_from,
created_by=user_id,
)
position += 1
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
db.session.commit()
# external_document_indexing_task.delay(dataset.id, external_knowledge_api_id, data_source, process_parameter)
return dataset
@staticmethod
def process_external_api(
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
@ -342,37 +272,3 @@ class ExternalDatasetService:
if response.status_code == 200:
return response.json().get("records", [])
return []
@staticmethod
def test_external_knowledge_retrieval(retrieval_setting: dict, query: str, external_knowledge_id: str):
client = boto3.client(
"bedrock-agent-runtime",
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
region_name="us-east-1",
)
response = client.retrieve(
knowledgeBaseId=external_knowledge_id,
retrievalConfiguration={
"vectorSearchConfiguration": {
"numberOfResults": retrieval_setting.get("top_k"),
"overrideSearchType": "HYBRID",
}
},
retrievalQuery={"text": query},
)
results = []
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
if response.get("retrievalResults"):
retrieval_results = response.get("retrievalResults")
for retrieval_result in retrieval_results:
if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0):
continue
result = {
"metadata": retrieval_result.get("metadata"),
"score": retrieval_result.get("score"),
"title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"),
"content": retrieval_result.get("content").get("text"),
}
results.append(result)
return {"records": results}

View File

@ -1,6 +1,7 @@
import datetime
import json
import requests
from flask_login import current_user
from core.helper import encrypter
@ -65,6 +66,35 @@ class WebsiteService:
time = str(datetime.datetime.now().timestamp())
redis_client.setex(website_crawl_time_cache_key, 3600, time)
return {"status": "active", "job_id": job_id}
elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
crawl_sub_pages = options.get("crawl_sub_pages", False)
if not crawl_sub_pages:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return {"status": "active", "data": response.json().get("data")}
else:
response = requests.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={
"url": url,
"maxPages": options.get("limit", 1),
"useSitemap": options.get("use_sitemap", True),
},
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
else:
raise ValueError("Invalid provider")
@ -93,6 +123,42 @@ class WebsiteService:
time_consuming = abs(end_time - float(start_time))
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
redis_client.delete(website_crawl_time_cache_key)
elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
)
data = response.json().get("data", {})
crawl_status_data = {
"status": data.get("status", "active"),
"job_id": job_id,
"total": len(data.get("urls", [])),
"current": len(data.get("processed", [])) + len(data.get("failed", [])),
"data": [],
"time_consuming": data.get("duration", 0) / 1000,
}
if crawl_status_data["status"] == "completed":
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
)
data = response.json().get("data", {})
formatted_data = [
{
"title": item.get("data", {}).get("title"),
"source_url": item.get("data", {}).get("url"),
"description": item.get("data", {}).get("description"),
"markdown": item.get("data", {}).get("content"),
}
for item in data.get("processed", {}).values()
]
crawl_status_data["data"] = formatted_data
else:
raise ValueError("Invalid provider")
return crawl_status_data
@ -100,6 +166,8 @@ class WebsiteService:
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
if provider == "firecrawl":
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
@ -107,8 +175,6 @@ class WebsiteService:
if data:
data = json.loads(data.decode("utf-8"))
else:
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
if result.get("status") != "completed":
@ -119,6 +185,40 @@ class WebsiteService:
if item.get("source_url") == url:
return item
return None
elif provider == "jinareader":
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
data = storage.load_once(file_key)
if data:
data = json.loads(data.decode("utf-8"))
elif not job_id:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return response.json().get("data")
else:
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
)
data = response.json().get("data", {})
if data.get("status") != "completed":
raise ValueError("Crawl job is not completed")
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
)
data = response.json().get("data", {})
for item in data.get("processed", {}).values():
if item.get("data", {}).get("url") == url:
return item.get("data", {})
else:
raise ValueError("Invalid provider")

View File

@ -0,0 +1,25 @@
import os
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.voyage.voyage import VoyageProvider
def test_validate_provider_credentials():
provider = VoyageProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
"model": "voyage-3",
"usage": {"total_tokens": 1},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})

View File

@ -0,0 +1,92 @@
import os
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel
def test_validate_credentials():
model = VoyageRerankModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="rerank-lite-1",
credentials={"api_key": "invalid_key"},
)
with patch("httpx.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [
{
"relevance_score": 0.546875,
"index": 0,
"document": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
"States Census, Carson City had a population of 55,274.",
},
{
"relevance_score": 0.4765625,
"index": 1,
"document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the "
"Pacific Ocean that are a political division controlled by the United States. Its "
"capital is Saipan.",
},
],
"model": "rerank-lite-1",
"usage": {"total_tokens": 96},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
model.validate_credentials(
model="rerank-lite-1",
credentials={
"api_key": os.environ.get("VOYAGE_API_KEY"),
},
)
def test_invoke_model():
model = VoyageRerankModel()
with patch("httpx.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [
{
"relevance_score": 0.84375,
"index": 0,
"document": "Kasumi is a girl name of Japanese origin meaning mist.",
},
{
"relevance_score": 0.4765625,
"index": 1,
"document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she "
"leads a team named PopiParty.",
},
],
"model": "rerank-lite-1",
"usage": {"total_tokens": 59},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
result = model.invoke(
model="rerank-lite-1",
credentials={
"api_key": os.environ.get("VOYAGE_API_KEY"),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl name of Japanese origin meaning mist.",
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
"PopiParty.",
],
score_threshold=0.5,
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 1
assert result.docs[0].index == 0
assert result.docs[0].score >= 0.5

View File

@ -0,0 +1,70 @@
import os
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel
def test_validate_credentials():
model = VoyageTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"})
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
"model": "voyage-3",
"usage": {"total_tokens": 1},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
def test_invoke_model():
model = VoyageTextEmbeddingModel()
with patch("requests.post") as mock_post:
mock_response = Mock()
mock_response.json.return_value = {
"object": "list",
"data": [
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0},
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1},
],
"model": "voyage-3",
"usage": {"total_tokens": 2},
}
mock_response.status_code = 200
mock_post.return_value = mock_response
result = model.invoke(
model="voyage-3",
credentials={
"api_key": os.environ.get("VOYAGE_API_KEY"),
},
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 2
def test_get_num_tokens():
model = VoyageTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model="voyage-3",
credentials={
"api_key": os.environ.get("VOYAGE_API_KEY"),
},
texts=["ping"],
)
assert num_tokens == 1

View File

@ -17,7 +17,7 @@ class MockedHttp:
request = httpx.Request(
method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies")
)
data = kwargs.get("data", None)
data = kwargs.get("data")
resp = json.dumps(data).encode("utf-8") if data else b"OK"
response = httpx.Response(
status_code=200,

View File

@ -22,8 +22,8 @@ class MockedHttp:
return response
# get data, files
data = kwargs.get("data", None)
files = kwargs.get("files", None)
data = kwargs.get("data")
files = kwargs.get("files")
if data is not None:
resp = dumps(data).encode("utf-8")
elif files is not None:

View File

@ -0,0 +1,38 @@
import pytest
from controllers.console.version import _has_new_version
@pytest.mark.parametrize(
("latest_version", "current_version", "expected"),
[
("1.0.1", "1.0.0", True),
("1.1.0", "1.0.0", True),
("2.0.0", "1.9.9", True),
("1.0.0", "1.0.0", False),
("1.0.0", "1.0.1", False),
("1.0.0", "2.0.0", False),
("1.0.1", "1.0.0-beta", True),
("1.0.0", "1.0.0-alpha", True),
("1.0.0-beta", "1.0.0-alpha", True),
("1.0.0", "1.0.0-rc1", True),
("1.0.0", "0.9.9", True),
("1.0.0", "1.0.0-dev", True),
],
)
def test_has_new_version(latest_version, current_version, expected):
assert _has_new_version(latest_version=latest_version, current_version=current_version) == expected
def test_has_new_version_invalid_input():
with pytest.raises(ValueError):
_has_new_version(latest_version="1.0", current_version="1.0.0")
with pytest.raises(ValueError):
_has_new_version(latest_version="1.0.0", current_version="1.0")
with pytest.raises(ValueError):
_has_new_version(latest_version="invalid", current_version="1.0.0")
with pytest.raises(ValueError):
_has_new_version(latest_version="1.0.0", current_version="invalid")

View File

@ -9,4 +9,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
api/tests/integration_tests/model_runtime/upstage \
api/tests/integration_tests/model_runtime/fireworks \
api/tests/integration_tests/model_runtime/nomic \
api/tests/integration_tests/model_runtime/mixedbread
api/tests/integration_tests/model_runtime/mixedbread \
api/tests/integration_tests/model_runtime/voyage

View File

@ -2,7 +2,7 @@ version: '3'
services:
# API service
api:
image: langgenius/dify-api:0.8.3
image: langgenius/dify-api:0.9.1
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -227,7 +227,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.8.3
image: langgenius/dify-api:0.9.1
restart: always
environment:
CONSOLE_WEB_URL: ''
@ -396,7 +396,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.8.3
image: langgenius/dify-web:0.9.1
restart: always
environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

View File

@ -213,7 +213,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:0.8.3
image: langgenius/dify-api:0.9.1
restart: always
environment:
# Use the shared environment variables.
@ -233,7 +233,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.8.3
image: langgenius/dify-api:0.9.1
restart: always
environment:
# Use the shared environment variables.
@ -252,7 +252,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.8.3
image: langgenius/dify-web:0.9.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -1,24 +0,0 @@
# unstructured .
# (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.)
unstructured:
image: downloads.unstructured.io/unstructured-io/unstructured-api:latest
profiles:
- unstructured
restart: always
volumes:
- ./volumes/unstructured:/app/data
networks:
# create a network between sandbox, api and ssrf_proxy, and can not access outside.
ssrf_proxy_network:
driver: bridge
internal: true
milvus:
driver: bridge
opensearch-net:
driver: bridge
internal: true
volumes:
oradata:

View File

@ -3,6 +3,7 @@ import type { FC } from 'react'
import React, { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Pagination } from 'react-headless-pagination'
import { useDebounce } from 'ahooks'
import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline'
import Toast from '../../base/toast'
import Filter from './filter'
@ -67,10 +68,11 @@ const Annotation: FC<Props> = ({
const [queryParams, setQueryParams] = useState<QueryParam>({})
const [currPage, setCurrPage] = React.useState<number>(0)
const debouncedQueryParams = useDebounce(queryParams, { wait: 500 })
const query = {
page: currPage + 1,
limit: APP_PAGE_LIMIT,
keyword: queryParams.keyword || '',
keyword: debouncedQueryParams.keyword || '',
}
const [controlUpdateList, setControlUpdateList] = useState(Date.now())
@ -232,8 +234,8 @@ const Annotation: FC<Props> = ({
middlePagesSiblingCount={1}
setCurrentPage={setCurrPage}
totalPages={Math.ceil(total / APP_PAGE_LIMIT)}
truncatableClassName="w-8 px-0.5 text-center"
truncatableText="..."
truncableClassName="w-8 px-0.5 text-center"
truncableText="..."
>
<Pagination.PrevButton
disabled={currPage === 0}

View File

@ -243,7 +243,7 @@ const ConfigContent: FC<Props> = ({
/>
)
}
<div className='leading-[32px] text-text-secondary system-sm-semibold ml-1'>{t('common.modelProvider.rerankModel.key')}</div>
<div className='leading-[32px] text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
<Tooltip
popupContent={
<div className="w-[200px]">

View File

@ -59,10 +59,10 @@ const SettingsModal: FC<SettingsModalProps> = ({
const { t } = useTranslation()
const { notify } = useToastContext()
const ref = useRef(null)
const isExternal = currentDataset.provider === 'external'
const [topK, setTopK] = useState(currentDataset?.external_retrieval_model.top_k ?? 2)
const [scoreThreshold, setScoreThreshold] = useState(currentDataset?.external_retrieval_model.score_threshold ?? 0.5)
const [scoreThresholdEnabled, setScoreThresholdEnabled] = useState(currentDataset?.external_retrieval_model.score_threshold_enabled ?? false)
const { setShowAccountSettingModal } = useModalContext()
const [loading, setLoading] = useState(false)
const { isCurrentWorkspaceDatasetOperator } = useAppContext()
@ -122,19 +122,21 @@ const SettingsModal: FC<SettingsModalProps> = ({
description,
permission,
indexing_technique: indexMethod,
external_retrieval_model: {
top_k: topK,
score_threshold: scoreThreshold,
score_threshold_enabled: scoreThresholdEnabled,
},
retrieval_model: {
...postRetrievalConfig,
score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0,
},
external_knowledge_id: currentDataset!.external_knowledge_info.external_knowledge_id,
external_knowledge_api_id: currentDataset!.external_knowledge_info.external_knowledge_api_id,
embedding_model: localeCurrentDataset.embedding_model,
embedding_model_provider: localeCurrentDataset.embedding_model_provider,
...(isExternal && {
external_knowledge_id: currentDataset!.external_knowledge_info.external_knowledge_id,
external_knowledge_api_id: currentDataset!.external_knowledge_info.external_knowledge_api_id,
external_retrieval_model: {
top_k: topK,
score_threshold: scoreThreshold,
score_threshold_enabled: scoreThresholdEnabled,
},
}),
},
} as any
if (permission === 'partial_members') {

View File

@ -4,6 +4,7 @@ import React, { useState } from 'react'
import useSWR from 'swr'
import { usePathname } from 'next/navigation'
import { Pagination } from 'react-headless-pagination'
import { useDebounce } from 'ahooks'
import { omit } from 'lodash-es'
import dayjs from 'dayjs'
import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline'
@ -59,6 +60,7 @@ const Logs: FC<ILogsProps> = ({ appDetail }) => {
sort_by: '-created_at',
})
const [currPage, setCurrPage] = React.useState<number>(0)
const debouncedQueryParams = useDebounce(queryParams, { wait: 500 })
// Get the app type first
const isChatMode = appDetail.mode !== 'completion'
@ -66,14 +68,14 @@ const Logs: FC<ILogsProps> = ({ appDetail }) => {
const query = {
page: currPage + 1,
limit: APP_PAGE_LIMIT,
...(queryParams.period !== 'all'
...(debouncedQueryParams.period !== 'all'
? {
start: dayjs().subtract(queryParams.period as number, 'day').startOf('day').format('YYYY-MM-DD HH:mm'),
start: dayjs().subtract(debouncedQueryParams.period as number, 'day').startOf('day').format('YYYY-MM-DD HH:mm'),
end: dayjs().endOf('day').format('YYYY-MM-DD HH:mm'),
}
: {}),
...(isChatMode ? { sort_by: queryParams.sort_by } : {}),
...omit(queryParams, ['period']),
...(isChatMode ? { sort_by: debouncedQueryParams.sort_by } : {}),
...omit(debouncedQueryParams, ['period']),
}
const getWebAppType = (appType: AppMode) => {
@ -119,8 +121,8 @@ const Logs: FC<ILogsProps> = ({ appDetail }) => {
middlePagesSiblingCount={1}
setCurrentPage={setCurrPage}
totalPages={Math.ceil(total / APP_PAGE_LIMIT)}
truncatableClassName="w-8 px-0.5 text-center"
truncatableText="..."
truncableClassName="w-8 px-0.5 text-center"
truncableText="..."
>
<Pagination.PrevButton
disabled={currPage === 0}

View File

@ -4,6 +4,7 @@ import React, { useState } from 'react'
import useSWR from 'swr'
import { usePathname } from 'next/navigation'
import { Pagination } from 'react-headless-pagination'
import { useDebounce } from 'ahooks'
import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline'
import { Trans, useTranslation } from 'react-i18next'
import Link from 'next/link'
@ -51,12 +52,13 @@ const Logs: FC<ILogsProps> = ({ appDetail }) => {
const { t } = useTranslation()
const [queryParams, setQueryParams] = useState<QueryParam>({ status: 'all' })
const [currPage, setCurrPage] = React.useState<number>(0)
const debouncedQueryParams = useDebounce(queryParams, { wait: 500 })
const query = {
page: currPage + 1,
limit: APP_PAGE_LIMIT,
...(queryParams.status !== 'all' ? { status: queryParams.status } : {}),
...(queryParams.keyword ? { keyword: queryParams.keyword } : {}),
...(debouncedQueryParams.status !== 'all' ? { status: debouncedQueryParams.status } : {}),
...(debouncedQueryParams.keyword ? { keyword: debouncedQueryParams.keyword } : {}),
}
const getWebAppType = (appType: AppMode) => {
@ -93,8 +95,8 @@ const Logs: FC<ILogsProps> = ({ appDetail }) => {
middlePagesSiblingCount={1}
setCurrentPage={setCurrPage}
totalPages={Math.ceil(total / APP_PAGE_LIMIT)}
truncatableClassName="w-8 px-0.5 text-center"
truncatableText="..."
truncableClassName="w-8 px-0.5 text-center"
truncableText="..."
>
<Pagination.PrevButton
disabled={currPage === 0}

Binary file not shown (new image, 2.7 KiB).

View File

@ -11,7 +11,7 @@ import { DataSourceType } from '@/models/datasets'
import type { CrawlOptions, CrawlResultItem, DataSet, FileItem, createDocumentResponse } from '@/models/datasets'
import { fetchDataSource } from '@/service/common'
import { fetchDatasetDetail } from '@/service/datasets'
import type { NotionPage } from '@/models/common'
import { DataSourceProvider, type NotionPage } from '@/models/common'
import { useModalContext } from '@/context/modal-context'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
@ -26,6 +26,7 @@ const DEFAULT_CRAWL_OPTIONS: CrawlOptions = {
excludes: '',
limit: 10,
max_depth: '',
use_sitemap: true,
}
const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
@ -51,7 +52,8 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
const updateFileList = (preparedFiles: FileItem[]) => {
setFiles(preparedFiles)
}
const [fireCrawlJobId, setFireCrawlJobId] = useState('')
const [websiteCrawlProvider, setWebsiteCrawlProvider] = useState<DataSourceProvider>(DataSourceProvider.fireCrawl)
const [websiteCrawlJobId, setWebsiteCrawlJobId] = useState('')
const updateFile = (fileItem: FileItem, progress: number, list: FileItem[]) => {
const targetIndex = list.findIndex(file => file.fileID === fileItem.fileID)
@ -137,7 +139,8 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
onStepChange={nextStep}
websitePages={websitePages}
updateWebsitePages={setWebsitePages}
onFireCrawlJobIdChange={setFireCrawlJobId}
onWebsiteCrawlProviderChange={setWebsiteCrawlProvider}
onWebsiteCrawlJobIdChange={setWebsiteCrawlJobId}
crawlOptions={crawlOptions}
onCrawlOptionsChange={setCrawlOptions}
/>
@ -151,7 +154,8 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
files={fileList.map(file => file.file)}
notionPages={notionPages}
websitePages={websitePages}
fireCrawlJobId={fireCrawlJobId}
websiteCrawlProvider={websiteCrawlProvider}
websiteCrawlJobId={websiteCrawlJobId}
onStepChange={changeStep}
updateIndexingTypeCache={updateIndexingTypeCache}
updateResultCache={updateResultCache}

View File

@ -10,7 +10,7 @@ import WebsitePreview from '../website/preview'
import s from './index.module.css'
import cn from '@/utils/classnames'
import type { CrawlOptions, CrawlResultItem, FileItem } from '@/models/datasets'
import type { NotionPage } from '@/models/common'
import type { DataSourceProvider, NotionPage } from '@/models/common'
import { DataSourceType } from '@/models/datasets'
import Button from '@/app/components/base/button'
import { NotionPageSelector } from '@/app/components/base/notion-page-selector'
@ -33,7 +33,8 @@ type IStepOneProps = {
changeType: (type: DataSourceType) => void
websitePages?: CrawlResultItem[]
updateWebsitePages: (value: CrawlResultItem[]) => void
onFireCrawlJobIdChange: (jobId: string) => void
onWebsiteCrawlProviderChange: (provider: DataSourceProvider) => void
onWebsiteCrawlJobIdChange: (jobId: string) => void
crawlOptions: CrawlOptions
onCrawlOptionsChange: (payload: CrawlOptions) => void
}
@ -69,7 +70,8 @@ const StepOne = ({
updateNotionPages,
websitePages = [],
updateWebsitePages,
onFireCrawlJobIdChange,
onWebsiteCrawlProviderChange,
onWebsiteCrawlJobIdChange,
crawlOptions,
onCrawlOptionsChange,
}: IStepOneProps) => {
@ -229,7 +231,8 @@ const StepOne = ({
onPreview={setCurrentWebsite}
checkedCrawlResult={websitePages}
onCheckedCrawlResultChange={updateWebsitePages}
onJobIdChange={onFireCrawlJobIdChange}
onCrawlProviderChange={onWebsiteCrawlProviderChange}
onJobIdChange={onWebsiteCrawlJobIdChange}
crawlOptions={crawlOptions}
onCrawlOptionsChange={onCrawlOptionsChange}
/>

View File

@ -33,6 +33,7 @@ import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/componen
import Toast from '@/app/components/base/toast'
import { formatNumber } from '@/utils/format'
import type { NotionPage } from '@/models/common'
import { DataSourceProvider } from '@/models/common'
import { DataSourceType, DocForm } from '@/models/datasets'
import NotionIcon from '@/app/components/base/notion-icon'
import Switch from '@/app/components/base/switch'
@ -63,7 +64,8 @@ type StepTwoProps = {
notionPages?: NotionPage[]
websitePages?: CrawlResultItem[]
crawlOptions?: CrawlOptions
fireCrawlJobId?: string
websiteCrawlProvider?: DataSourceProvider
websiteCrawlJobId?: string
onStepChange?: (delta: number) => void
updateIndexingTypeCache?: (type: string) => void
updateResultCache?: (res: createDocumentResponse) => void
@ -94,7 +96,8 @@ const StepTwo = ({
notionPages = [],
websitePages = [],
crawlOptions,
fireCrawlJobId = '',
websiteCrawlProvider = DataSourceProvider.fireCrawl,
websiteCrawlJobId = '',
onStepChange,
updateIndexingTypeCache,
updateResultCache,
@ -260,8 +263,8 @@ const StepTwo = ({
const getWebsiteInfo = () => {
return {
provider: 'firecrawl',
job_id: fireCrawlJobId,
provider: websiteCrawlProvider,
job_id: websiteCrawlJobId,
urls: websitePages.map(page => page.source_url),
only_main_content: crawlOptions?.only_main_content,
}

View File

@ -3,6 +3,7 @@ import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import Checkbox from '@/app/components/base/checkbox'
import Tooltip from '@/app/components/base/tooltip'
type Props = {
className?: string
@ -10,6 +11,7 @@ type Props = {
onChange: (isChecked: boolean) => void
label: string
labelClassName?: string
tooltip?: string
}
const CheckboxWithLabel: FC<Props> = ({
@ -18,11 +20,20 @@ const CheckboxWithLabel: FC<Props> = ({
onChange,
label,
labelClassName,
tooltip,
}) => {
return (
<label className={cn(className, 'flex items-center h-7 space-x-2')}>
<Checkbox checked={isChecked} onCheck={() => onChange(!isChecked)} />
<div className={cn(labelClassName, 'text-sm font-normal text-gray-800')}>{label}</div>
{tooltip && (
<Tooltip
popupContent={
<div className='w-[200px]'>{tooltip}</div>
}
triggerClassName='ml-0.5 w-4 h-4'
/>
)}
</label>
)
}

Some files were not shown because too many files have changed in this diff.