Merge branch 'refs/heads/main' into feat/workflow-parallel-support

# Conflicts:
#	api/core/workflow/workflow_entry.py
This commit is contained in:
takatost
2024-07-24 23:43:14 +08:00
128 changed files with 4008 additions and 1419 deletions

View File

@ -62,7 +62,12 @@ class DatasetConfigManager:
return None
# dataset configs
dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
if 'dataset_configs' in config and config.get('dataset_configs'):
dataset_configs = config.get('dataset_configs')
else:
dataset_configs = {
'retrieval_model': 'multiple'
}
query_variable = config.get('dataset_query_variable')
if dataset_configs['retrieval_model'] == 'single':
@ -83,9 +88,10 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
top_k=dataset_configs.get('top_k'),
top_k=dataset_configs.get('top_k', 4),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model')
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights')
)
)
@ -114,12 +120,6 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if config["dataset_configs"]['retrieval_model'] == 'multiple':
if not config["dataset_configs"]['reranking_model']:
raise ValueError("reranking_model has not been set")
if not isinstance(config["dataset_configs"]['reranking_model'], dict):
raise ValueError("reranking_model must be of object type")
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")

View File

@ -159,7 +159,11 @@ class DatasetRetrieveConfigEntity(BaseModel):
retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None
score_threshold: Optional[float] = None
rerank_mode: Optional[str] = 'reranking_model'
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
class DatasetEntity(BaseModel):

View File

@ -1,11 +1,12 @@
from .segment_group import SegmentGroup
from .segments import Segment
from .segments import NoneSegment, Segment
from .types import SegmentType
from .variables import (
ArrayVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
@ -23,5 +24,7 @@ __all__ = [
'Variable',
'SegmentType',
'SegmentGroup',
'Segment'
'Segment',
'NoneSegment',
'NoneVariable',
]

View File

@ -10,6 +10,7 @@ from .variables import (
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
@ -39,6 +40,8 @@ def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
def build_anonymous_variable(value: Any, /) -> Variable:
if value is None:
return NoneVariable(name='anonymous')
if isinstance(value, str):
return StringVariable(name='anonymous', value=value)
if isinstance(value, int):

View File

@ -43,6 +43,23 @@ class Segment(BaseModel):
return self.value
class NoneSegment(Segment):
value_type: SegmentType = SegmentType.NONE
value: None = None
@property
def text(self) -> str:
return 'null'
@property
def log(self) -> str:
return 'null'
@property
def markdown(self) -> str:
return 'null'
class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING
value: str

View File

@ -2,16 +2,10 @@ from enum import Enum
class SegmentType(str, Enum):
STRING = 'string'
NONE = 'none'
NUMBER = 'number'
FILE = 'file'
STRING = 'string'
SECRET = 'secret'
OBJECT = 'object'
ARRAY = 'array'
ARRAY_STRING = 'array[string]'
ARRAY_NUMBER = 'array[number]'
ARRAY_OBJECT = 'array[object]'
ARRAY_FILE = 'array[file]'
OBJECT = 'object'
FILE = 'file'

View File

@ -6,7 +6,7 @@ from pydantic import Field
from core.file.file_obj import FileVar
from core.helper import encrypter
from .segments import Segment, StringSegment
from .segments import NoneSegment, Segment, StringSegment
from .types import SegmentType
@ -20,6 +20,7 @@ class Variable(Segment):
description="Unique identity for variable. It's only used by environment variables now.",
)
name: str
description: str = Field(default='', description='Description of the variable.')
class StringVariable(StringSegment, Variable):
@ -81,3 +82,8 @@ class SecretVariable(StringVariable):
@property
def log(self) -> str:
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):
value_type: SegmentType = SegmentType.NONE
value: None = None