mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts: # api/core/app/apps/advanced_chat/app_generator.py # api/core/app/apps/advanced_chat/generate_task_pipeline.py # api/core/app/apps/workflow/app_runner.py # api/core/app/apps/workflow/generate_task_pipeline.py # api/core/app/task_pipeline/workflow_cycle_state_manager.py # api/core/workflow/entities/variable_pool.py # api/core/workflow/nodes/code/code_node.py # api/core/workflow/nodes/llm/llm_node.py # api/core/workflow/nodes/start/start_node.py # api/core/workflow/nodes/variable_assigner/__init__.py # api/tests/integration_tests/workflow/nodes/test_llm.py # api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py # api/tests/unit_tests/core/workflow/nodes/test_answer.py # api/tests/unit_tests/core/workflow/nodes/test_if_else.py # api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
This commit is contained in:
@ -1,13 +1,13 @@
|
||||
from core.app.segments import SecretVariable, StringSegment, parser
|
||||
from core.helper import encrypter
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
||||
|
||||
def test_segment_group_to_text():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariable('user_id'): 'fake-user-id',
|
||||
SystemVariableKey('user_id'): 'fake-user-id',
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[
|
||||
@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group():
|
||||
def test_convert_variable_to_segment_group():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariable('user_id'): 'fake-user-id',
|
||||
SystemVariableKey('user_id'): 'fake-user-id',
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
|
||||
0
api/tests/unit_tests/core/model_runtime/__init__.py
Normal file
0
api/tests/unit_tests/core/model_runtime/__init__.py
Normal file
@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import (
|
||||
TextEmbedding,
|
||||
WenxinTextEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
def test_max_chunks():
|
||||
class _MockTextEmbedding(TextEmbedding):
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))]
|
||||
tokens = 0
|
||||
for text in texts:
|
||||
tokens += len(text)
|
||||
|
||||
return embeddings, tokens, tokens
|
||||
|
||||
def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
|
||||
return _MockTextEmbedding()
|
||||
|
||||
model = 'embedding-v1'
|
||||
credentials = {
|
||||
'api_key': 'xxxx',
|
||||
'secret_key': 'yyyy',
|
||||
}
|
||||
embedding_model = WenxinTextEmbeddingModel()
|
||||
context_size = embedding_model._get_context_size(model, credentials)
|
||||
max_chunks = embedding_model._get_max_chunks(model, credentials)
|
||||
embedding_model._create_text_embedding = _create_text_embedding
|
||||
|
||||
texts = ['0123456789' for i in range(0, max_chunks * 2)]
|
||||
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
|
||||
assert len(result.embeddings) == max_chunks * 2
|
||||
|
||||
|
||||
def test_context_size():
|
||||
def get_num_tokens_by_gpt2(text: str) -> int:
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
|
||||
def mock_text(token_size: int) -> str:
|
||||
_text = "".join(['0' for i in range(token_size)])
|
||||
num_tokens = get_num_tokens_by_gpt2(_text)
|
||||
ratio = int(np.floor(len(_text) / num_tokens))
|
||||
m_text = "".join([_text for i in range(ratio)])
|
||||
return m_text
|
||||
|
||||
model = 'embedding-v1'
|
||||
credentials = {
|
||||
'api_key': 'xxxx',
|
||||
'secret_key': 'yyyy',
|
||||
}
|
||||
embedding_model = WenxinTextEmbeddingModel()
|
||||
context_size = embedding_model._get_context_size(model, credentials)
|
||||
|
||||
class _MockTextEmbedding(TextEmbedding):
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))]
|
||||
tokens = 0
|
||||
for text in texts:
|
||||
tokens += get_num_tokens_by_gpt2(text)
|
||||
return embeddings, tokens, tokens
|
||||
|
||||
def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
|
||||
return _MockTextEmbedding()
|
||||
|
||||
embedding_model._create_text_embedding = _create_text_embedding
|
||||
text = mock_text(context_size * 2)
|
||||
assert get_num_tokens_by_gpt2(text) == context_size * 2
|
||||
|
||||
texts = [text]
|
||||
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
|
||||
assert result.usage.tokens == context_size
|
||||
@ -5,7 +5,7 @@ from unittest.mock import MagicMock
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
@ -60,8 +60,8 @@ def test_execute_answer():
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
@ -88,6 +88,14 @@ def test_execute_answer():
|
||||
}
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'weather'], 'sunny')
|
||||
pool.add(['llm', 'text'], 'You are a helpful AI.')
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from unittest.mock import MagicMock
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
@ -45,8 +45,8 @@ def test_execute_if_else_result_true():
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={})
|
||||
pool.add(['start', 'array_contains'], ['ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
|
||||
@ -176,8 +176,8 @@ def test_execute_if_else_result_true():
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
|
||||
@ -250,8 +250,8 @@ def test_execute_if_else_result_false():
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
||||
|
||||
@ -7,7 +7,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
@ -72,7 +72,7 @@ def test_overwrite_string_variable():
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
@ -101,7 +101,7 @@ def test_overwrite_string_variable():
|
||||
},
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
|
||||
list(node.run())
|
||||
mock_run.assert_called_once()
|
||||
|
||||
@ -165,7 +165,7 @@ def test_append_variable_to_array():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
@ -193,7 +193,7 @@ def test_append_variable_to_array():
|
||||
},
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
|
||||
list(node.run())
|
||||
mock_run.assert_called_once()
|
||||
|
||||
@ -250,7 +250,7 @@ def test_clear_array():
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
|
||||
@ -11,7 +11,17 @@ def test_environment_variables():
|
||||
contexts.tenant_id.set('tenant_id')
|
||||
|
||||
# Create a Workflow instance
|
||||
workflow = Workflow()
|
||||
workflow = Workflow(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
type='workflow',
|
||||
version='draft',
|
||||
graph='{}',
|
||||
features='{}',
|
||||
created_by='account_id',
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create some EnvironmentVariable instances
|
||||
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
|
||||
@ -35,7 +45,17 @@ def test_update_environment_variables():
|
||||
contexts.tenant_id.set('tenant_id')
|
||||
|
||||
# Create a Workflow instance
|
||||
workflow = Workflow()
|
||||
workflow = Workflow(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
type='workflow',
|
||||
version='draft',
|
||||
graph='{}',
|
||||
features='{}',
|
||||
created_by='account_id',
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create some EnvironmentVariable instances
|
||||
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
|
||||
@ -70,9 +90,17 @@ def test_to_dict():
|
||||
contexts.tenant_id.set('tenant_id')
|
||||
|
||||
# Create a Workflow instance
|
||||
workflow = Workflow()
|
||||
workflow.graph = '{}'
|
||||
workflow.features = '{}'
|
||||
workflow = Workflow(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
type='workflow',
|
||||
version='draft',
|
||||
graph='{}',
|
||||
features='{}',
|
||||
created_by='account_id',
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create some EnvironmentVariable instances
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from core.app.app_config.entities import (
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
VariableEntityType,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
@pytest.fixture
|
||||
def default_variables():
|
||||
return [
|
||||
value = [
|
||||
VariableEntity(
|
||||
variable="text_input",
|
||||
label="text-input",
|
||||
type=VariableEntity.Type.TEXT_INPUT
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="paragraph",
|
||||
label="paragraph",
|
||||
type=VariableEntity.Type.PARAGRAPH
|
||||
type=VariableEntityType.PARAGRAPH,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="select",
|
||||
label="select",
|
||||
type=VariableEntity.Type.SELECT
|
||||
)
|
||||
type=VariableEntityType.SELECT,
|
||||
),
|
||||
]
|
||||
return value
|
||||
|
||||
|
||||
def test__convert_to_start_node(default_variables):
|
||||
|
||||
@ -2,7 +2,7 @@ from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.position_helper import get_position_map
|
||||
from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -14,7 +14,7 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||
- second
|
||||
# - commented
|
||||
- third
|
||||
|
||||
|
||||
- 9999999999999
|
||||
- forth
|
||||
"""))
|
||||
@ -28,9 +28,9 @@ def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||
"""\
|
||||
# - commented1
|
||||
# - commented2
|
||||
-
|
||||
-
|
||||
|
||||
-
|
||||
-
|
||||
|
||||
"""))
|
||||
return str(tmp_path)
|
||||
|
||||
@ -53,3 +53,79 @@ def test_position_helper_with_all_commented(prepare_empty_commented_positions_ya
|
||||
folder_path=prepare_empty_commented_positions_yaml,
|
||||
file_name='example_positions_all_commented.yaml')
|
||||
assert position_map == {}
|
||||
|
||||
|
||||
def test_excluded_position_data(prepare_example_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_example_positions_yaml,
|
||||
file_name='example_positions.yaml'
|
||||
)
|
||||
pin_list = ['forth', 'first']
|
||||
include_set = set()
|
||||
exclude_set = {'9999999999999'}
|
||||
|
||||
position_map = pin_position_map(
|
||||
original_position_map=position_map,
|
||||
pin_list=pin_list
|
||||
)
|
||||
|
||||
data = [
|
||||
"forth",
|
||||
"first",
|
||||
"second",
|
||||
"third",
|
||||
"9999999999999",
|
||||
"extra1",
|
||||
"extra2",
|
||||
]
|
||||
|
||||
# filter out the data
|
||||
data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
|
||||
|
||||
# sort data by position map
|
||||
sorted_data = sort_by_position_map(
|
||||
position_map=position_map,
|
||||
data=data,
|
||||
name_func=lambda x: x,
|
||||
)
|
||||
|
||||
# assert the result in the correct order
|
||||
assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2']
|
||||
|
||||
|
||||
def test_included_position_data(prepare_example_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_example_positions_yaml,
|
||||
file_name='example_positions.yaml'
|
||||
)
|
||||
pin_list = ['forth', 'first']
|
||||
include_set = {'forth', 'first'}
|
||||
exclude_set = {}
|
||||
|
||||
position_map = pin_position_map(
|
||||
original_position_map=position_map,
|
||||
pin_list=pin_list
|
||||
)
|
||||
|
||||
data = [
|
||||
"forth",
|
||||
"first",
|
||||
"second",
|
||||
"third",
|
||||
"9999999999999",
|
||||
"extra1",
|
||||
"extra2",
|
||||
]
|
||||
|
||||
# filter out the data
|
||||
data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
|
||||
|
||||
# sort data by position map
|
||||
sorted_data = sort_by_position_map(
|
||||
position_map=position_map,
|
||||
data=data,
|
||||
name_func=lambda x: x,
|
||||
)
|
||||
|
||||
# assert the result in the correct order
|
||||
assert sorted_data == ['forth', 'first']
|
||||
|
||||
Reference in New Issue
Block a user