mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 07:58:02 +08:00
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts: # api/core/app/apps/advanced_chat/app_generator.py # api/core/app/apps/advanced_chat/app_runner.py # api/core/app/apps/advanced_chat/generate_task_pipeline.py # api/core/app/apps/workflow/app_runner.py # api/core/app/task_pipeline/workflow_cycle_manage.py # api/core/workflow/entities/variable_pool.py # api/core/workflow/nodes/base_node.py # api/core/workflow/workflow_engine_manager.py
This commit is contained in:
@ -79,4 +79,7 @@ CODE_EXECUTION_API_KEY=
|
||||
VOLC_API_KEY=
|
||||
VOLC_SECRET_KEY=
|
||||
VOLC_MODEL_ENDPOINT_ID=
|
||||
VOLC_EMBEDDING_ENDPOINT_ID=
|
||||
VOLC_EMBEDDING_ENDPOINT_ID=
|
||||
|
||||
# 360 AI Credentials
|
||||
ZHINAO_API_KEY=
|
||||
|
||||
@ -0,0 +1,94 @@
|
||||
|
||||
from api.core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
|
||||
|
||||
|
||||
class MockTEIClass:
|
||||
@staticmethod
|
||||
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
|
||||
# During mock, we don't have a real server to query, so we just return a dummy value
|
||||
if 'rerank' in model_name:
|
||||
model_type = 'reranker'
|
||||
else:
|
||||
model_type = 'embedding'
|
||||
|
||||
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
|
||||
|
||||
@staticmethod
|
||||
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
|
||||
# Use space as token separator, and split the text into tokens
|
||||
tokenized_texts = []
|
||||
for text in texts:
|
||||
tokens = text.split(' ')
|
||||
current_index = 0
|
||||
tokenized_text = []
|
||||
for idx, token in enumerate(tokens):
|
||||
s_token = {
|
||||
'id': idx,
|
||||
'text': token,
|
||||
'special': False,
|
||||
'start': current_index,
|
||||
'stop': current_index + len(token),
|
||||
}
|
||||
current_index += len(token) + 1
|
||||
tokenized_text.append(s_token)
|
||||
tokenized_texts.append(tokenized_text)
|
||||
return tokenized_texts
|
||||
|
||||
@staticmethod
|
||||
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
|
||||
# {
|
||||
# "object": "list",
|
||||
# "data": [
|
||||
# {
|
||||
# "object": "embedding",
|
||||
# "embedding": [...],
|
||||
# "index": 0
|
||||
# }
|
||||
# ],
|
||||
# "model": "MODEL_NAME",
|
||||
# "usage": {
|
||||
# "prompt_tokens": 3,
|
||||
# "total_tokens": 3
|
||||
# }
|
||||
# }
|
||||
embeddings = []
|
||||
for idx, text in enumerate(texts):
|
||||
embedding = [0.1] * 768
|
||||
embeddings.append(
|
||||
{
|
||||
'object': 'embedding',
|
||||
'embedding': embedding,
|
||||
'index': idx,
|
||||
}
|
||||
)
|
||||
return {
|
||||
'object': 'list',
|
||||
'data': embeddings,
|
||||
'model': 'MODEL_NAME',
|
||||
'usage': {
|
||||
'prompt_tokens': sum(len(text.split(' ')) for text in texts),
|
||||
'total_tokens': sum(len(text.split(' ')) for text in texts),
|
||||
},
|
||||
}
|
||||
|
||||
def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
|
||||
# Example response:
|
||||
# [
|
||||
# {
|
||||
# "index": 0,
|
||||
# "text": "Deep Learning is ...",
|
||||
# "score": 0.9950755
|
||||
# }
|
||||
# ]
|
||||
reranked_docs = []
|
||||
for idx, text in enumerate(texts):
|
||||
reranked_docs.append(
|
||||
{
|
||||
'index': idx,
|
||||
'text': text,
|
||||
'score': 0.9,
|
||||
}
|
||||
)
|
||||
# For mock, only return the first document
|
||||
break
|
||||
return reranked_docs
|
||||
@ -106,7 +106,7 @@ class MockXinferenceClass:
|
||||
def _check_cluster_authenticated(self):
|
||||
self._cluster_authed = True
|
||||
|
||||
def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int) -> dict:
|
||||
def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
|
||||
# check if self._model_uid is a valid uuid
|
||||
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
|
||||
self._model_uid != 'rerank':
|
||||
|
||||
@ -0,0 +1,72 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from api.core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
|
||||
HuggingfaceTeiTextEmbeddingModel,
|
||||
)
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
|
||||
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
|
||||
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
|
||||
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
|
||||
def test_validate_credentials(setup_tei_mock):
|
||||
model = HuggingfaceTeiTextEmbeddingModel()
|
||||
# model name is only used in mock
|
||||
model_name = 'embedding'
|
||||
|
||||
if MOCK:
|
||||
# TEI Provider will check model type by API endpoint, at real server, the model type is correct.
|
||||
# So we dont need to check model type here. Only check in mock
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='reranker',
|
||||
credentials={
|
||||
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model=model_name,
|
||||
credentials={
|
||||
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
|
||||
def test_invoke_model(setup_tei_mock):
|
||||
model = HuggingfaceTeiTextEmbeddingModel()
|
||||
model_name = 'embedding'
|
||||
|
||||
result = model.invoke(
|
||||
model=model_name,
|
||||
credentials={
|
||||
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens > 0
|
||||
@ -0,0 +1,76 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
|
||||
HuggingfaceTeiRerankModel,
|
||||
)
|
||||
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
|
||||
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
|
||||
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
|
||||
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
|
||||
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
|
||||
def test_validate_credentials(setup_tei_mock):
|
||||
model = HuggingfaceTeiRerankModel()
|
||||
# model name is only used in mock
|
||||
model_name = 'reranker'
|
||||
|
||||
if MOCK:
|
||||
# TEI Provider will check model type by API endpoint, at real server, the model type is correct.
|
||||
# So we dont need to check model type here. Only check in mock
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='embedding',
|
||||
credentials={
|
||||
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model=model_name,
|
||||
credentials={
|
||||
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
|
||||
def test_invoke_model(setup_tei_mock):
|
||||
model = HuggingfaceTeiRerankModel()
|
||||
# model name is only used in mock
|
||||
model_name = 'reranker'
|
||||
|
||||
result = model.invoke(
|
||||
model=model_name,
|
||||
credentials={
|
||||
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
|
||||
},
|
||||
query="Who is Kasumi?",
|
||||
docs=[
|
||||
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
|
||||
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
|
||||
"and she leads a team named PopiParty."
|
||||
],
|
||||
score_threshold=0.8
|
||||
)
|
||||
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 1
|
||||
assert result.docs[0].index == 0
|
||||
assert result.docs[0].score >= 0.8
|
||||
@ -0,0 +1,59 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import (
|
||||
OAICompatSpeech2TextModel,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = OAICompatSpeech2TextModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="whisper-1",
|
||||
credentials={
|
||||
"api_key": "invalid_key",
|
||||
"endpoint_url": "https://api.openai.com/v1/"
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="whisper-1",
|
||||
credentials={
|
||||
"api_key": os.environ.get("OPENAI_API_KEY"),
|
||||
"endpoint_url": "https://api.openai.com/v1/"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = OAICompatSpeech2TextModel()
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Get assets directory
|
||||
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
|
||||
|
||||
# Construct the path to the audio file
|
||||
audio_file_path = os.path.join(assets_dir, "audio.mp3")
|
||||
|
||||
# Open the file and get the file object
|
||||
with open(audio_file_path, "rb") as audio_file:
|
||||
file = audio_file
|
||||
|
||||
result = model.invoke(
|
||||
model="whisper-1",
|
||||
credentials={
|
||||
"api_key": os.environ.get("OPENAI_API_KEY"),
|
||||
"endpoint_url": "https://api.openai.com/v1/"
|
||||
},
|
||||
file=file,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
|
||||
@ -0,0 +1,53 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.siliconflow.speech2text.speech2text import SiliconflowSpeech2TextModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = SiliconflowSpeech2TextModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="iic/SenseVoiceSmall",
|
||||
credentials={
|
||||
"api_key": "invalid_key"
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="iic/SenseVoiceSmall",
|
||||
credentials={
|
||||
"api_key": os.environ.get("API_KEY")
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = SiliconflowSpeech2TextModel()
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Get assets directory
|
||||
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
|
||||
|
||||
# Construct the path to the audio file
|
||||
audio_file_path = os.path.join(assets_dir, "audio.mp3")
|
||||
|
||||
# Open the file and get the file object
|
||||
with open(audio_file_path, "rb") as audio_file:
|
||||
file = audio_file
|
||||
|
||||
result = model.invoke(
|
||||
model="iic/SenseVoiceSmall",
|
||||
credentials={
|
||||
"api_key": os.environ.get("API_KEY")
|
||||
},
|
||||
file=file
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == '1,2,3,4,5,6,7,8,9,10.'
|
||||
@ -0,0 +1,62 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.siliconflow.text_embedding.text_embedding import (
|
||||
SiliconflowTextEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = SiliconflowTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="BAAI/bge-large-zh-v1.5",
|
||||
credentials={
|
||||
"api_key": "invalid_key"
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="BAAI/bge-large-zh-v1.5",
|
||||
credentials={
|
||||
"api_key": os.environ.get("API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = SiliconflowTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="BAAI/bge-large-zh-v1.5",
|
||||
credentials={
|
||||
"api_key": os.environ.get("API_KEY"),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world",
|
||||
],
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 6
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = SiliconflowTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="BAAI/bge-large-zh-v1.5",
|
||||
credentials={
|
||||
"api_key": os.environ.get("API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
245
api/tests/integration_tests/model_runtime/upstage/test_llm.py
Normal file
245
api/tests/integration_tests/model_runtime/upstage/test_llm.py
Normal file
@ -0,0 +1,245 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.upstage.llm.llm import UpstageLargeLanguageModel
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
def test_predefined_models():
|
||||
model = UpstageLargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_validate_credentials_for_chat_model(setup_openai_mock):
|
||||
model = UpstageLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
# model name to gpt-3.5-turbo because of mocking
|
||||
model.validate_credentials(
|
||||
model='gpt-3.5-turbo',
|
||||
credentials={
|
||||
'upstage_api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='solar-1-mini-chat',
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_invoke_chat_model(setup_openai_mock):
|
||||
model = UpstageLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='solar-1-mini-chat',
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'top_p': 1.0,
|
||||
'presence_penalty': 0.0,
|
||||
'frequency_penalty': 0.0,
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_invoke_chat_model_with_tools(setup_openai_mock):
|
||||
model = UpstageLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='solar-1-mini-chat',
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content="what's the weather today in London?",
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens': 100
|
||||
},
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name='get_weather',
|
||||
description='Determine weather in my location',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"c",
|
||||
"f"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
),
|
||||
PromptMessageTool(
|
||||
name='get_stock_price',
|
||||
description='Get the current stock price',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "The stock symbol"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"symbol"
|
||||
]
|
||||
}
|
||||
)
|
||||
],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert isinstance(result.message, AssistantPromptMessage)
|
||||
assert len(result.message.tool_calls) > 0
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_invoke_stream_chat_model(setup_openai_mock):
|
||||
model = UpstageLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='solar-1-mini-chat',
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens': 100
|
||||
},
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
assert chunk.delta.usage.completion_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = UpstageLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='solar-1-mini-chat',
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 13
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='solar-1-mini-chat',
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name='get_weather',
|
||||
description='Determine weather in my location',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"c",
|
||||
"f"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 106
|
||||
@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.upstage.upstage import UpstageProvider
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_validate_provider_credentials(setup_openai_mock):
|
||||
provider = UpstageProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
|
||||
}
|
||||
)
|
||||
@ -0,0 +1,67 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.upstage.text_embedding.text_embedding import UpstageTextEmbeddingModel
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
|
||||
def test_validate_credentials(setup_openai_mock):
|
||||
model = UpstageTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='solar-embedding-1-large-passage',
|
||||
credentials={
|
||||
'upstage_api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='solar-embedding-1-large-passage',
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
|
||||
def test_invoke_model(setup_openai_mock):
|
||||
model = UpstageTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='solar-embedding-1-large-passage',
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world",
|
||||
" ".join(["long_text"] * 100),
|
||||
" ".join(["another_long_text"] * 100)
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 4
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = UpstageTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='solar-embedding-1-large-passage',
|
||||
credentials={
|
||||
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 5
|
||||
106
api/tests/integration_tests/model_runtime/zhinao/test_llm.py
Normal file
106
api/tests/integration_tests/model_runtime/zhinao/test_llm.py
Normal file
@ -0,0 +1,106 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.zhinao.llm.llm import ZhinaoLargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = ZhinaoLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='360gpt2-pro',
|
||||
credentials={
|
||||
'api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='360gpt2-pro',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHINAO_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = ZhinaoLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='360gpt2-pro',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHINAO_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.5,
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = ZhinaoLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='360gpt2-pro',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHINAO_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.5,
|
||||
'max_tokens': 100,
|
||||
'seed': 1234
|
||||
},
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = ZhinaoLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='360gpt2-pro',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHINAO_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 21
|
||||
@ -0,0 +1,21 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.zhinao.zhinao import ZhinaoProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = ZhinaoProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHINAO_API_KEY')
|
||||
}
|
||||
)
|
||||
@ -7,15 +7,16 @@ from core.app.segments import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileSegment,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
factory,
|
||||
)
|
||||
from core.app.segments.exc import VariableError
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
@ -44,7 +45,7 @@ def test_secret_variable():
|
||||
|
||||
def test_invalid_value_type():
|
||||
test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'}
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(test_data)
|
||||
|
||||
|
||||
@ -67,7 +68,7 @@ def test_build_a_object_variable_with_none_value():
|
||||
}
|
||||
)
|
||||
assert isinstance(var, ObjectSegment)
|
||||
assert isinstance(var.value['key1'], NoneSegment)
|
||||
assert var.value['key1'] is None
|
||||
|
||||
|
||||
def test_object_variable():
|
||||
@ -77,26 +78,14 @@ def test_object_variable():
|
||||
'name': 'test_object',
|
||||
'description': 'Description of the variable.',
|
||||
'value': {
|
||||
'key1': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'key2': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 1,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'key1': 'text',
|
||||
'key2': 2,
|
||||
},
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ObjectSegment)
|
||||
assert isinstance(variable.value['key1'], StringVariable)
|
||||
assert isinstance(variable.value['key2'], IntegerVariable)
|
||||
assert isinstance(variable.value['key1'], str)
|
||||
assert isinstance(variable.value['key2'], int)
|
||||
|
||||
|
||||
def test_array_string_variable():
|
||||
@ -106,26 +95,14 @@ def test_array_string_variable():
|
||||
'name': 'test_array',
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'text',
|
||||
'text',
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayStringVariable)
|
||||
assert isinstance(variable.value[0], StringVariable)
|
||||
assert isinstance(variable.value[1], StringVariable)
|
||||
assert isinstance(variable.value[0], str)
|
||||
assert isinstance(variable.value[1], str)
|
||||
|
||||
|
||||
def test_array_number_variable():
|
||||
@ -135,26 +112,14 @@ def test_array_number_variable():
|
||||
'name': 'test_array',
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 1,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 2.0,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
1,
|
||||
2.0,
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayNumberVariable)
|
||||
assert isinstance(variable.value[0], IntegerVariable)
|
||||
assert isinstance(variable.value[1], FloatVariable)
|
||||
assert isinstance(variable.value[0], int)
|
||||
assert isinstance(variable.value[1], float)
|
||||
|
||||
|
||||
def test_array_object_variable():
|
||||
@ -165,59 +130,23 @@ def test_array_object_variable():
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'object',
|
||||
'name': 'object',
|
||||
'description': 'Description of the variable.',
|
||||
'value': {
|
||||
'key1': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'key2': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 1,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
},
|
||||
'key1': 'text',
|
||||
'key2': 1,
|
||||
},
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'object',
|
||||
'name': 'object',
|
||||
'description': 'Description of the variable.',
|
||||
'value': {
|
||||
'key1': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'text',
|
||||
'value': 'text',
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
'key2': {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'number',
|
||||
'name': 'number',
|
||||
'value': 1,
|
||||
'description': 'Description of the variable.',
|
||||
},
|
||||
},
|
||||
'key1': 'text',
|
||||
'key2': 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayObjectVariable)
|
||||
assert isinstance(variable.value[0], ObjectSegment)
|
||||
assert isinstance(variable.value[1], ObjectSegment)
|
||||
assert isinstance(variable.value[0].value['key1'], StringVariable)
|
||||
assert isinstance(variable.value[0].value['key2'], IntegerVariable)
|
||||
assert isinstance(variable.value[1].value['key1'], StringVariable)
|
||||
assert isinstance(variable.value[1].value['key2'], IntegerVariable)
|
||||
assert isinstance(variable.value[0], dict)
|
||||
assert isinstance(variable.value[1], dict)
|
||||
assert isinstance(variable.value[0]['key1'], str)
|
||||
assert isinstance(variable.value[0]['key2'], int)
|
||||
assert isinstance(variable.value[1]['key1'], str)
|
||||
assert isinstance(variable.value[1]['key2'], int)
|
||||
|
||||
|
||||
def test_file_variable():
|
||||
@ -257,51 +186,53 @@ def test_array_file_variable():
|
||||
'value': [
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'name': 'file',
|
||||
'value_type': 'file',
|
||||
'value': {
|
||||
'id': str(uuid4()),
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'name': 'file',
|
||||
'value_type': 'file',
|
||||
'value': {
|
||||
'id': str(uuid4()),
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayFileVariable)
|
||||
assert isinstance(variable.value[0], FileVariable)
|
||||
assert isinstance(variable.value[1], FileVariable)
|
||||
assert isinstance(variable.value[0], FileSegment)
|
||||
assert isinstance(variable.value[1], FileSegment)
|
||||
|
||||
|
||||
def test_variable_cannot_large_than_5_kb():
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'test_text',
|
||||
'value': 'a' * 1024 * 6,
|
||||
}
|
||||
)
|
||||
|
||||
@ -54,20 +54,10 @@ def test_object_variable_to_object():
|
||||
var = ObjectVariable(
|
||||
name='object',
|
||||
value={
|
||||
'key1': ObjectVariable(
|
||||
name='object',
|
||||
value={
|
||||
'key2': StringVariable(name='key2', value='value2'),
|
||||
},
|
||||
),
|
||||
'key2': ArrayAnyVariable(
|
||||
name='array',
|
||||
value=[
|
||||
StringVariable(name='key5_1', value='value5_1'),
|
||||
IntegerVariable(name='key5_2', value=42),
|
||||
ObjectVariable(name='key5_3', value={}),
|
||||
],
|
||||
),
|
||||
'key1': {
|
||||
'key2': 'value2',
|
||||
},
|
||||
'key2': ['value5_1', 42, {}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -2,8 +2,8 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import FileExtraConfig, ModelConfigEntity
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
|
||||
@ -0,0 +1,150 @@
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||
|
||||
DEFAULT_NODE_ID = 'node_id'
|
||||
|
||||
|
||||
def test_overwrite_string_variable():
|
||||
conversation_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value='the first value',
|
||||
)
|
||||
|
||||
input_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_string_variable',
|
||||
value='the second value',
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.OVER_WRITE.value,
|
||||
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
variable_pool.add(
|
||||
[DEFAULT_NODE_ID, input_variable.name],
|
||||
input_variable,
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
node.run(variable_pool)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.value == 'the second value'
|
||||
assert got.to_object() == 'the second value'
|
||||
|
||||
|
||||
def test_append_variable_to_array():
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value=['the first value'],
|
||||
)
|
||||
|
||||
input_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_string_variable',
|
||||
value='the second value',
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.APPEND.value,
|
||||
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
variable_pool.add(
|
||||
[DEFAULT_NODE_ID, input_variable.name],
|
||||
input_variable,
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
node.run(variable_pool)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == ['the first value', 'the second value']
|
||||
|
||||
|
||||
def test_clear_array():
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value=['the first value'],
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.CLEAR.value,
|
||||
'input_variable_selector': [],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
node.run(variable_pool)
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == []
|
||||
25
api/tests/unit_tests/models/test_conversation_variable.py
Normal file
25
api/tests/unit_tests/models/test_conversation_variable.py
Normal file
@ -0,0 +1,25 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.segments import SegmentType, factory
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
def test_from_variable_and_to_variable():
|
||||
variable = factory.build_variable_from_mapping(
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'name': 'name',
|
||||
'value_type': SegmentType.OBJECT,
|
||||
'value': {
|
||||
'key': {
|
||||
'key': 'value',
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
conversation_variable = ConversationVariable.from_variable(
|
||||
app_id='app_id', conversation_id='conversation_id', variable=variable
|
||||
)
|
||||
|
||||
assert conversation_variable.to_variable() == variable
|
||||
@ -208,7 +208,8 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot():
|
||||
reranking_model={
|
||||
'reranking_provider_name': 'cohere',
|
||||
'reranking_model_name': 'rerank-english-v2.0'
|
||||
}
|
||||
},
|
||||
reranking_enabled=True
|
||||
)
|
||||
)
|
||||
|
||||
@ -251,7 +252,8 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app():
|
||||
reranking_model={
|
||||
'reranking_provider_name': 'cohere',
|
||||
'reranking_model_name': 'rerank-english-v2.0'
|
||||
}
|
||||
},
|
||||
reranking_enabled=True
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user