mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
Introduce Plugins (#13836)
Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: xhe <xw897002528@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: takatost <takatost@gmail.com> Co-authored-by: kurokobo <kuro664@gmail.com> Co-authored-by: Novice Lee <novicelee@NoviPro.local> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: AkaraChen <akarachen@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com> Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com> Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com> Co-authored-by: eux <euxuuu@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: lotsik <lotsik@mail.ru> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com> Co-authored-by: CN-P5 <heibai2006@gmail.com> Co-authored-by: CN-P5 <heibai2006@qq.com> Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Boris Feld 
<lothiraldan@gmail.com> Co-authored-by: mbo <himabo@gmail.com> Co-authored-by: mabo <mabo@aeyes.ai> Co-authored-by: Warren Chen <warren.chen830@gmail.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: jiandanfeng <chenjh3@wangsu.com> Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com> Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com> Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: rayshaw001 <396301947@163.com> Co-authored-by: Ding Jiatong <dingjiatong@gmail.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: JasonVV <jasonwangiii@outlook.com> Co-authored-by: le0zh <newlight@qq.com> Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com> Co-authored-by: k-zaku <zaku99@outlook.jp> Co-authored-by: luckylhb90 <luckylhb90@gmail.com> Co-authored-by: hobo.l <hobo.l@binance.com> Co-authored-by: jiangbo721 <365065261@qq.com> Co-authored-by: 刘江波 <jiangbo721@163.com> Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com> Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: sino <sino2322@gmail.com> Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com> Co-authored-by: lowell <lowell.hu@zkteco.in> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com> Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com> Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com> Co-authored-by: Jason <ggbbddjm@gmail.com> Co-authored-by: Xin Zhang <sjhpzx@gmail.com> Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com> Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com> Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com> 
Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com> Co-authored-by: Yingchun Lai <laiyingchun@apache.org> Co-authored-by: Hash Brown <hi@xzd.me> Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com> Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com> Co-authored-by: aplio <ryo.091219@gmail.com> Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com> Co-authored-by: Nam Vu <zuzoovn@gmail.com> Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com> Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com> Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com> Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp> Co-authored-by: HQidea <HQidea@users.noreply.github.com> Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com> Co-authored-by: xhe <xw897002528@gmail.com> Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com> Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com> Co-authored-by: engchina <12236799+engchina@users.noreply.github.com> Co-authored-by: engchina <atjapan2015@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: 呆萌闷油瓶 <253605712@qq.com> Co-authored-by: Kemal <kemalmeler@outlook.com> Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com> Co-authored-by: steven <sunzwj@digitalchina.com> Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com> Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com> Co-authored-by: 胡春东 <gycm520@gmail.com> Co-authored-by: Junjie.M <118170653@qq.com> 
Co-authored-by: MuYu <mr.muzea@gmail.com> Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com> Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Co-authored-by: Fei He <droxer.he@gmail.com> Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com> Co-authored-by: AugNSo <song.tiankai@icloud.com> Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com> Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com> Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com> Co-authored-by: Hundredwz <1808096180@qq.com> Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
This commit is contained in:
@ -84,6 +84,13 @@ VOLC_EMBEDDING_ENDPOINT_ID=

# 360 AI Credentials
ZHINAO_API_KEY=

# Plugin configuration
PLUGIN_DAEMON_KEY=
PLUGIN_DAEMON_URL=
INNER_API_KEY=

# Marketplace configuration
MARKETPLACE_API_URL=

# VESSL AI Credentials
VESSL_AI_MODEL_NAME=
VESSL_AI_API_KEY=

@ -1,98 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from anthropic import Stream
|
||||
from anthropic.resources import Messages
|
||||
from anthropic.types import (
|
||||
ContentBlock,
|
||||
ContentBlockDeltaEvent,
|
||||
Message,
|
||||
MessageDeltaEvent,
|
||||
MessageDeltaUsage,
|
||||
MessageParam,
|
||||
MessageStartEvent,
|
||||
MessageStopEvent,
|
||||
MessageStreamEvent,
|
||||
TextDelta,
|
||||
Usage,
|
||||
)
|
||||
from anthropic.types.message_delta_event import Delta
|
||||
|
||||
# Consistency fix: every other mock module compares MOCK_SWITCH
# case-insensitively; this one forgot the .lower(), so "True"/"TRUE" were
# silently treated as "mocking off".
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
class MockAnthropicClass:
    """Drop-in mock for the Anthropic Messages API used by integration tests."""

    @staticmethod
    def mocked_anthropic_chat_create_sync(model: str) -> Message:
        """Return one canned, non-streaming assistant message."""
        return Message(
            id="msg-123",
            type="message",
            role="assistant",
            content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
            model=model,
            stop_reason="stop_sequence",
            usage=Usage(input_tokens=1, output_tokens=1),
        )

    @staticmethod
    def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
        """Yield a canned streaming reply, one character per delta event."""
        answer = "hello, I'm a chatbot from anthropic"

        # Stream opening: an empty assistant-message envelope.
        yield MessageStartEvent(
            type="message_start",
            message=Message(
                id="msg-123",
                content=[],
                role="assistant",
                model=model,
                stop_reason=None,
                type="message",
                usage=Usage(input_tokens=1, output_tokens=1),
            ),
        )

        # One content-block delta per character of the canned answer.
        for position, char in enumerate(answer):
            yield ContentBlockDeltaEvent(
                type="content_block_delta", delta=TextDelta(text=char, type="text_delta"), index=position
            )

        # Closing bookkeeping: stop reason plus final usage, then the stop marker.
        yield MessageDeltaEvent(
            type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
        )
        yield MessageStopEvent(type="message_stop")

    def mocked_anthropic(
        self: Messages,
        *,
        max_tokens: int,
        messages: Iterable[MessageParam],
        model: str,
        stream: Literal[True],
        **kwargs: Any,
    ) -> Union[Message, Stream[MessageStreamEvent]]:
        """Replacement for Messages.create: validate the key, then dispatch."""
        # Reject implausibly short API keys, mimicking the real SDK's auth check.
        if len(self._client.api_key) < 18:
            raise anthropic.AuthenticationError("Invalid API key")

        if stream:
            return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
        return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)
|
||||
|
||||
|
||||
@pytest.fixture
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
    """Swap Anthropic's Messages.create for the local mock while MOCK is on."""
    patched = MOCK
    if patched:
        monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)

    yield

    if patched:
        monkeypatch.undo()
|
||||
@ -1,82 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
|
||||
def mock_get(*args, **kwargs):
    """Mock httpx.get for the FishAudio model-list endpoint.

    Requires the literal bearer token "test"; anything else raises a 401
    HTTPStatusError, mirroring an auth failure from the real service.
    """
    auth_header = kwargs.get("headers", {}).get("Authorization")
    if auth_header != "Bearer test":
        raise httpx.HTTPStatusError(
            "Invalid API key",
            request=httpx.Request("GET", ""),
            response=httpx.Response(401),
        )

    payload = {
        "items": [
            {"title": "Model 1", "_id": "model1"},
            {"title": "Model 2", "_id": "model2"},
        ]
    }
    return httpx.Response(200, json=payload, request=httpx.Request("GET", ""))
|
||||
|
||||
|
||||
def mock_stream(*args, **kwargs):
    """Mock httpx.stream: a context manager yielding one chunk of fake audio."""

    class MockStreamResponse:
        def __init__(self):
            # Always report success; the TTS code only checks the status code.
            self.status_code = 200

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            return None

        def iter_bytes(self):
            # Single canned chunk standing in for streamed audio bytes.
            yield b"Mocked audio data"

    return MockStreamResponse()
|
||||
|
||||
|
||||
def mock_fishaudio(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["list-models", "tts"]],
) -> Callable[[], None]:
    """
    mock fishaudio module

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: which httpx entry points to patch
    :return: unpatch function
    """
    # Patch only the httpx entry points the caller asked for.
    if "list-models" in methods:
        monkeypatch.setattr(httpx, "get", mock_get)
    if "tts" in methods:
        monkeypatch.setattr(httpx, "stream", mock_stream)

    def unpatch() -> None:
        monkeypatch.undo()

    return unpatch
|
||||
|
||||
|
||||
# Mocking is opt-in via the MOCK_SWITCH environment variable (case-insensitive).
MOCK = "true" == os.getenv("MOCK_SWITCH", "false").lower()
|
||||
|
||||
|
||||
@pytest.fixture
def setup_fishaudio_mock(request, monkeypatch):
    """Enable the FishAudio mock for tests that opt in via `request.param`."""
    selected = getattr(request, "param", [])
    unpatch = mock_fishaudio(monkeypatch, methods=selected) if MOCK else None

    yield

    if unpatch is not None:
        unpatch()
|
||||
@ -1,115 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import google.generativeai.types.generation_types as generation_config_types # type: ignore
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from google.ai import generativelanguage as glm
|
||||
from google.ai.generativelanguage_v1beta.types import content as gag_content
|
||||
from google.generativeai import GenerativeModel
|
||||
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
|
||||
from google.generativeai.types.generation_types import BaseGenerateContentResponse
|
||||
|
||||
from extensions import ext_redis
|
||||
|
||||
|
||||
class MockGoogleResponseClass:
    """Iterable standing in for a streamed GenerateContentResponse."""

    _done = False  # flipped to True once the final chunk has been emitted

    def __iter__(self):
        canned_text = "it's google!"

        # Emit one chunk per character, then one extra chunk flagged done=True.
        for step in range(len(canned_text) + 1):
            finished = step == len(canned_text)
            if finished:
                self._done = True
            yield GenerateContentResponse(
                done=finished, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
            )
|
||||
|
||||
|
||||
class MockGoogleResponseCandidateClass:
    """Single mocked candidate with a fixed finish reason and canned content."""

    finish_reason = "stop"

    @property
    def content(self) -> gag_content.Content:
        # Always the same canned reply, wrapped in the protobuf content type.
        return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
|
||||
|
||||
|
||||
class MockGoogleClass:
    """Mocked entry points patched onto google.generativeai during tests."""

    @staticmethod
    def generate_content_sync() -> GenerateContentResponse:
        """Return one completed, empty response object."""
        return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])

    @staticmethod
    def generate_content_stream() -> MockGoogleResponseClass:
        """Return an iterable that fakes a streamed response."""
        return MockGoogleResponseClass()

    def generate_content(
        self: GenerativeModel,
        contents: content_types.ContentsType,
        *,
        generation_config: generation_config_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
        stream: bool = False,
        **kwargs,
    ) -> GenerateContentResponse:
        """Replacement for GenerativeModel.generate_content; dispatch on stream."""
        if stream:
            return MockGoogleClass.generate_content_stream()
        return MockGoogleClass.generate_content_sync()

    @property
    def generative_response_text(self) -> str:
        """Canned `.text` value patched onto BaseGenerateContentResponse."""
        return "it's google!"

    @property
    def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
        """Canned `.candidates` value patched onto BaseGenerateContentResponse."""
        return [MockGoogleResponseCandidateClass()]
|
||||
|
||||
|
||||
def mock_configure(api_key: str):
    """Stand-in for google.generativeai.configure.

    Accepts any key of at least 16 characters; shorter keys are rejected the
    same way the real SDK refuses a malformed key.
    """
    key_long_enough = len(api_key) >= 16
    if not key_long_enough:
        raise Exception("Invalid API key")
|
||||
|
||||
|
||||
class MockFileState:
    """File-state stub whose name always reports a finished upload."""

    def __init__(self):
        self.name = "FINISHED"


class MockGoogleFile:
    """Stub of a google.generativeai uploaded-file handle."""

    def __init__(self, name: str = "mock_file_name"):
        self.name = name
        # Uploads complete instantly in the mock.
        self.state = MockFileState()
|
||||
|
||||
|
||||
def mock_get_file(name: str) -> MockGoogleFile:
    """Mock google.generativeai.get_file: echo the requested name back."""
    return MockGoogleFile(name)


def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
    """Mock google.generativeai.upload_file: ignore the payload entirely."""
    return MockGoogleFile()
|
||||
|
||||
|
||||
@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
    """Patch the google.generativeai surface used by the runtime with mocks."""
    # Class-attribute patches on the SDK's response/model types.
    for target, attribute, replacement in (
        (BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text),
        (BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates),
        (GenerativeModel, "generate_content", MockGoogleClass.generate_content),
    ):
        monkeypatch.setattr(target, attribute, replacement)

    # Module-level function patches, addressed by dotted path.
    for dotted_path, replacement in (
        ("google.generativeai.configure", mock_configure),
        ("google.generativeai.get_file", mock_get_file),
        ("google.generativeai.upload_file", mock_upload_file),
    ):
        monkeypatch.setattr(dotted_path, replacement)

    yield

    monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture
def setup_mock_redis() -> None:
    """Replace the shared redis client's accessors with canned MagicMocks.

    get -> cache miss, setex -> accepted, exists -> present: lets caching code
    run without a live Redis instance.
    """
    for attribute, returned in (("get", None), ("setex", None), ("exists", True)):
        setattr(ext_redis.redis_client, attribute, MagicMock(return_value=returned))
|
||||
@ -1,20 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from huggingface_hub import InferenceClient # type: ignore
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
|
||||
|
||||
# Mocking is opt-in via the MOCK_SWITCH environment variable (case-insensitive).
MOCK = "true" == os.getenv("MOCK_SWITCH", "false").lower()
|
||||
|
||||
|
||||
@pytest.fixture
def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
    """Route InferenceClient.text_generation to the local mock when enabled."""
    patched = MOCK
    if patched:
        monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)

    yield

    if patched:
        monkeypatch.undo()
|
||||
@ -1,56 +0,0 @@
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from huggingface_hub import InferenceClient # type: ignore
|
||||
from huggingface_hub.inference._text_generation import ( # type: ignore
|
||||
Details,
|
||||
StreamDetails,
|
||||
TextGenerationResponse,
|
||||
TextGenerationStreamResponse,
|
||||
Token,
|
||||
)
|
||||
from huggingface_hub.utils import BadRequestError # type: ignore
|
||||
|
||||
|
||||
class MockHuggingfaceChatClass:
    """Mock of InferenceClient.text_generation for integration tests."""

    @staticmethod
    def generate_create_sync(model: str) -> TextGenerationResponse:
        """Return one canned, non-streaming generation with fixed details."""
        return TextGenerationResponse(
            generated_text="You can call me Miku Miku o~e~o~",
            details=Details(
                finish_reason="length",
                generated_tokens=6,
                tokens=[Token(id=0, text="You", logprob=0.0, special=False) for _ in range(6)],
            ),
        )

    @staticmethod
    def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]:
        """Yield the canned reply one character at a time."""
        canned = "You can call me Miku Miku o~e~o~"

        for position, char in enumerate(canned):
            chunk = TextGenerationStreamResponse(
                token=Token(id=position, text=char, logprob=0.0, special=False),
            )
            chunk.generated_text = char
            chunk.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)
            yield chunk

    def text_generation(
        self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
    ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
        """Replacement for InferenceClient.text_generation."""
        # check if key is valid — the mock only accepts bearer keys of the
        # expected shape pulled from the client's auth header.
        if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
            raise BadRequestError("Invalid API key")

        if model is None:
            raise BadRequestError("Invalid model")

        if stream:
            return MockHuggingfaceChatClass.generate_create_stream(model)
        return MockHuggingfaceChatClass.generate_create_sync(model)
|
||||
@ -1,94 +0,0 @@
|
||||
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
|
||||
|
||||
|
||||
class MockTEIClass:
|
||||
@staticmethod
|
||||
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
|
||||
# During mock, we don't have a real server to query, so we just return a dummy value
|
||||
if "rerank" in model_name:
|
||||
model_type = "reranker"
|
||||
else:
|
||||
model_type = "embedding"
|
||||
|
||||
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
|
||||
|
||||
@staticmethod
|
||||
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
|
||||
# Use space as token separator, and split the text into tokens
|
||||
tokenized_texts = []
|
||||
for text in texts:
|
||||
tokens = text.split(" ")
|
||||
current_index = 0
|
||||
tokenized_text = []
|
||||
for idx, token in enumerate(tokens):
|
||||
s_token = {
|
||||
"id": idx,
|
||||
"text": token,
|
||||
"special": False,
|
||||
"start": current_index,
|
||||
"stop": current_index + len(token),
|
||||
}
|
||||
current_index += len(token) + 1
|
||||
tokenized_text.append(s_token)
|
||||
tokenized_texts.append(tokenized_text)
|
||||
return tokenized_texts
|
||||
|
||||
@staticmethod
|
||||
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
|
||||
# {
|
||||
# "object": "list",
|
||||
# "data": [
|
||||
# {
|
||||
# "object": "embedding",
|
||||
# "embedding": [...],
|
||||
# "index": 0
|
||||
# }
|
||||
# ],
|
||||
# "model": "MODEL_NAME",
|
||||
# "usage": {
|
||||
# "prompt_tokens": 3,
|
||||
# "total_tokens": 3
|
||||
# }
|
||||
# }
|
||||
embeddings = []
|
||||
for idx in range(len(texts)):
|
||||
embedding = [0.1] * 768
|
||||
embeddings.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": embedding,
|
||||
"index": idx,
|
||||
}
|
||||
)
|
||||
return {
|
||||
"object": "list",
|
||||
"data": embeddings,
|
||||
"model": "MODEL_NAME",
|
||||
"usage": {
|
||||
"prompt_tokens": sum(len(text.split(" ")) for text in texts),
|
||||
"total_tokens": sum(len(text.split(" ")) for text in texts),
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
|
||||
# Example response:
|
||||
# [
|
||||
# {
|
||||
# "index": 0,
|
||||
# "text": "Deep Learning is ...",
|
||||
# "score": 0.9950755
|
||||
# }
|
||||
# ]
|
||||
reranked_docs = []
|
||||
for idx, text in enumerate(texts):
|
||||
reranked_docs.append(
|
||||
{
|
||||
"index": idx,
|
||||
"text": text,
|
||||
"score": 0.9,
|
||||
}
|
||||
)
|
||||
# For mock, only return the first document
|
||||
break
|
||||
return reranked_docs
|
||||
@ -1,59 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from nomic import embed # type: ignore
|
||||
|
||||
|
||||
def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict:
    """Mock nomic.embed.text: constant 768-dim vectors plus usage accounting."""
    count = len(texts)
    sample_value = 0.123456

    # One identical 768-dim vector per input text; token usage equals the
    # number of texts, matching the original stub's accounting.
    return {
        "embeddings": [[sample_value] * 768 for _ in range(count)],
        "usage": {"prompt_tokens": count, "total_tokens": count},
        "model": model,
        "inference_mode": "remote",
    }
|
||||
|
||||
|
||||
def mock_nomic(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["text_embedding"]],
) -> Callable[[], None]:
    """
    mock nomic module

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: mock targets to enable (only "text_embedding" is known)
    :return: unpatch function
    """
    if "text_embedding" in methods:
        monkeypatch.setattr(embed, "text", create_embedding)

    def unpatch() -> None:
        monkeypatch.undo()

    return unpatch
|
||||
|
||||
|
||||
# Mocking is opt-in via the MOCK_SWITCH environment variable (case-insensitive).
MOCK = "true" == os.getenv("MOCK_SWITCH", "false").lower()
|
||||
|
||||
|
||||
@pytest.fixture
def setup_nomic_mock(request, monkeypatch):
    """Enable the nomic mock for tests that opt in via `request.param`."""
    selected = getattr(request, "param", [])
    unpatch = mock_nomic(monkeypatch, methods=selected) if MOCK else None

    yield

    if unpatch is not None:
        unpatch()
|
||||
@ -1,71 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
|
||||
import pytest
|
||||
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from openai.resources.audio.transcriptions import Transcriptions
|
||||
from openai.resources.chat import Completions as ChatCompletions
|
||||
from openai.resources.completions import Completions
|
||||
from openai.resources.embeddings import Embeddings
|
||||
from openai.resources.models import Models
|
||||
from openai.resources.moderations import Moderations
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
|
||||
|
||||
|
||||
def mock_openai(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
) -> Callable[[], None]:
    """
    mock openai module

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: which OpenAI resources to patch
    :return: unpatch function
    """
    # Map each mockable capability to (resource class, attribute, replacement).
    patch_table = {
        "completion": (Completions, "create", MockCompletionsClass.completion_create),
        "chat": (ChatCompletions, "create", MockChatClass.chat_create),
        "remote": (Models, "list", MockModelClass.list),
        "moderation": (Moderations, "create", MockModerationClass.moderation_create),
        "speech2text": (Transcriptions, "create", MockSpeech2TextClass.speech2text_create),
        "text_embedding": (Embeddings, "create", MockEmbeddingsClass.create_embeddings),
    }
    for method, (resource, attribute, replacement) in patch_table.items():
        if method in methods:
            monkeypatch.setattr(resource, attribute, replacement)

    def unpatch() -> None:
        monkeypatch.undo()

    return unpatch
|
||||
|
||||
|
||||
# Mocking is opt-in via the MOCK_SWITCH environment variable (case-insensitive).
MOCK = "true" == os.getenv("MOCK_SWITCH", "false").lower()
|
||||
|
||||
|
||||
@pytest.fixture
def setup_openai_mock(request, monkeypatch):
    """Enable the OpenAI mocks for tests that opt in via `request.param`."""
    selected = getattr(request, "param", [])
    unpatch = mock_openai(monkeypatch, methods=selected) if MOCK else None

    yield

    if unpatch is not None:
        unpatch()
|
||||
@ -1,267 +0,0 @@
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from json import dumps
|
||||
from time import time
|
||||
|
||||
# import monkeypatch
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.resources.chat.completions import Completions
|
||||
from openai.types import Completion as CompletionMessage
|
||||
from openai.types.chat import (
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCall,
|
||||
ChatCompletionToolParam,
|
||||
completion_create_params,
|
||||
)
|
||||
from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion
|
||||
from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
Choice,
|
||||
ChoiceDelta,
|
||||
ChoiceDeltaFunctionCall,
|
||||
ChoiceDeltaToolCall,
|
||||
ChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage, FunctionCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class MockChatClass:
|
||||
@staticmethod
|
||||
def generate_function_call(
|
||||
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
) -> Optional[FunctionCall]:
|
||||
if not functions or len(functions) == 0:
|
||||
return None
|
||||
function: completion_create_params.Function = functions[0]
|
||||
function_name = function["name"]
|
||||
function_description = function["description"]
|
||||
function_parameters = function["parameters"]
|
||||
function_parameters_type = function_parameters["type"]
|
||||
if function_parameters_type != "object":
|
||||
return None
|
||||
function_parameters_properties = function_parameters["properties"]
|
||||
function_parameters_required = function_parameters["required"]
|
||||
parameters = {}
|
||||
for parameter_name, parameter in function_parameters_properties.items():
|
||||
if parameter_name not in function_parameters_required:
|
||||
continue
|
||||
parameter_type = parameter["type"]
|
||||
if parameter_type == "string":
|
||||
if "enum" in parameter:
|
||||
if len(parameter["enum"]) == 0:
|
||||
continue
|
||||
parameters[parameter_name] = parameter["enum"][0]
|
||||
else:
|
||||
parameters[parameter_name] = "kawaii"
|
||||
elif parameter_type == "integer":
|
||||
parameters[parameter_name] = 114514
|
||||
elif parameter_type == "number":
|
||||
parameters[parameter_name] = 1919810.0
|
||||
elif parameter_type == "boolean":
|
||||
parameters[parameter_name] = True
|
||||
|
||||
return FunctionCall(name=function_name, arguments=dumps(parameters))
|
||||
|
||||
@staticmethod
|
||||
def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
|
||||
list_tool_calls = []
|
||||
if not tools or len(tools) == 0:
|
||||
return None
|
||||
tool = tools[0]
|
||||
|
||||
if "type" in tools and tools["type"] != "function":
|
||||
return None
|
||||
|
||||
function = tool["function"]
|
||||
|
||||
function_call = MockChatClass.generate_function_call(functions=[function])
|
||||
if function_call is None:
|
||||
return None
|
||||
|
||||
list_tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id="sakurajima-mai",
|
||||
function=Function(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
)
|
||||
|
||||
return list_tool_calls
|
||||
|
||||
@staticmethod
|
||||
def mocked_openai_chat_create_sync(
|
||||
model: str,
|
||||
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
) -> CompletionMessage:
|
||||
tool_calls = []
|
||||
function_call = MockChatClass.generate_function_call(functions=functions)
|
||||
if not function_call:
|
||||
tool_calls = MockChatClass.generate_tool_calls(tools=tools)
|
||||
|
||||
return _ChatCompletion(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
choices=[
|
||||
_ChatCompletionChoice(
|
||||
finish_reason="content_filter",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
|
||||
),
|
||||
)
|
||||
],
|
||||
created=int(time()),
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
system_fingerprint="",
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=1,
|
||||
total_tokens=3,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_openai_chat_create_stream(
|
||||
model: str,
|
||||
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
) -> Generator[ChatCompletionChunk, None, None]:
|
||||
tool_calls = []
|
||||
function_call = MockChatClass.generate_function_call(functions=functions)
|
||||
if not function_call:
|
||||
tool_calls = MockChatClass.generate_tool_calls(tools=tools)
|
||||
|
||||
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
|
||||
for i in range(0, len(full_text) + 1):
|
||||
if i == len(full_text):
|
||||
yield ChatCompletionChunk(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
choices=[
|
||||
Choice(
|
||||
delta=ChoiceDelta(
|
||||
content="",
|
||||
function_call=ChoiceDeltaFunctionCall(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
)
|
||||
if function_call
|
||||
else None,
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id="misaka-mikoto",
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name=tool_calls[0].function.name,
|
||||
arguments=tool_calls[0].function.arguments,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
]
|
||||
if tool_calls and len(tool_calls) > 0
|
||||
else None,
|
||||
),
|
||||
finish_reason="function_call",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=int(time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
system_fingerprint="",
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=17,
|
||||
total_tokens=19,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionChunk(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
choices=[
|
||||
Choice(
|
||||
delta=ChoiceDelta(
|
||||
content=full_text[i],
|
||||
role="assistant",
|
||||
),
|
||||
finish_reason="content_filter",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=int(time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def chat_create(
|
||||
self: Completions,
|
||||
*,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model: Union[
|
||||
str,
|
||||
Literal[
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
],
|
||||
],
|
||||
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
|
||||
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
|
||||
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any,
|
||||
):
|
||||
openai_models = [
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
]
|
||||
azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
|
||||
raise InvokeAuthorizationError("Invalid base url")
|
||||
if model in openai_models + azure_openai_models:
|
||||
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
|
||||
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
|
||||
# so we only check if model is in openai_models
|
||||
raise InvokeAuthorizationError("Invalid api key")
|
||||
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
|
||||
raise InvokeAuthorizationError("Invalid api key")
|
||||
if stream:
|
||||
return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
|
||||
|
||||
return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
|
||||
@ -1,130 +0,0 @@
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from time import time
|
||||
|
||||
# import monkeypatch
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from openai import AzureOpenAI, BadRequestError, OpenAI
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.resources.completions import Completions
|
||||
from openai.types import Completion as CompletionMessage
|
||||
from openai.types.completion import CompletionChoice
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class MockCompletionsClass:
|
||||
@staticmethod
|
||||
def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
|
||||
return CompletionMessage(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
object="text_completion",
|
||||
created=int(time()),
|
||||
model=model,
|
||||
system_fingerprint="",
|
||||
choices=[
|
||||
CompletionChoice(
|
||||
text="mock",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=1,
|
||||
total_tokens=3,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
|
||||
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
|
||||
for i in range(0, len(full_text) + 1):
|
||||
if i == len(full_text):
|
||||
yield CompletionMessage(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
object="text_completion",
|
||||
created=int(time()),
|
||||
model=model,
|
||||
system_fingerprint="",
|
||||
choices=[
|
||||
CompletionChoice(
|
||||
text="",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=17,
|
||||
total_tokens=19,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield CompletionMessage(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
object="text_completion",
|
||||
created=int(time()),
|
||||
model=model,
|
||||
system_fingerprint="",
|
||||
choices=[
|
||||
CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
|
||||
],
|
||||
)
|
||||
|
||||
def completion_create(
|
||||
self: Completions,
|
||||
*,
|
||||
model: Union[
|
||||
str,
|
||||
Literal[
|
||||
"babbage-002",
|
||||
"davinci-002",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"text-davinci-003",
|
||||
"text-davinci-002",
|
||||
"text-davinci-001",
|
||||
"code-davinci-002",
|
||||
"text-curie-001",
|
||||
"text-babbage-001",
|
||||
"text-ada-001",
|
||||
],
|
||||
],
|
||||
prompt: Union[str, list[str], list[int], list[list[int]], None],
|
||||
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any,
|
||||
):
|
||||
openai_models = [
|
||||
"babbage-002",
|
||||
"davinci-002",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"text-davinci-003",
|
||||
"text-davinci-002",
|
||||
"text-davinci-001",
|
||||
"code-davinci-002",
|
||||
"text-curie-001",
|
||||
"text-babbage-001",
|
||||
"text-ada-001",
|
||||
]
|
||||
azure_openai_models = ["gpt-35-turbo-instruct"]
|
||||
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
|
||||
raise InvokeAuthorizationError("Invalid base url")
|
||||
if model in openai_models + azure_openai_models:
|
||||
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
|
||||
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
|
||||
# so we only check if model is in openai_models
|
||||
raise InvokeAuthorizationError("Invalid api key")
|
||||
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
|
||||
raise InvokeAuthorizationError("Invalid api key")
|
||||
|
||||
if not prompt:
|
||||
raise BadRequestError("Invalid prompt")
|
||||
if stream:
|
||||
return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
|
||||
|
||||
return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
|
||||
File diff suppressed because one or more lines are too long
@ -1,140 +0,0 @@
|
||||
import re
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.resources.moderations import Moderations
|
||||
from openai.types import ModerationCreateResponse
|
||||
from openai.types.moderation import Categories, CategoryScores, Moderation
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class MockModerationClass:
|
||||
def moderation_create(
|
||||
self: Moderations,
|
||||
*,
|
||||
input: Union[str, list[str]],
|
||||
model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any,
|
||||
) -> ModerationCreateResponse:
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
|
||||
raise InvokeAuthorizationError("Invalid base url")
|
||||
|
||||
if len(self._client.api_key) < 18:
|
||||
raise InvokeAuthorizationError("Invalid API key")
|
||||
|
||||
for text in input:
|
||||
result = []
|
||||
if "kill" in text:
|
||||
moderation_categories = {
|
||||
"harassment": False,
|
||||
"harassment/threatening": False,
|
||||
"hate": False,
|
||||
"hate/threatening": False,
|
||||
"self-harm": False,
|
||||
"self-harm/instructions": False,
|
||||
"self-harm/intent": False,
|
||||
"sexual": False,
|
||||
"sexual/minors": False,
|
||||
"violence": False,
|
||||
"violence/graphic": False,
|
||||
"illicit": False,
|
||||
"illicit/violent": False,
|
||||
}
|
||||
moderation_categories_scores = {
|
||||
"harassment": 1.0,
|
||||
"harassment/threatening": 1.0,
|
||||
"hate": 1.0,
|
||||
"hate/threatening": 1.0,
|
||||
"self-harm": 1.0,
|
||||
"self-harm/instructions": 1.0,
|
||||
"self-harm/intent": 1.0,
|
||||
"sexual": 1.0,
|
||||
"sexual/minors": 1.0,
|
||||
"violence": 1.0,
|
||||
"violence/graphic": 1.0,
|
||||
"illicit": 1.0,
|
||||
"illicit/violent": 1.0,
|
||||
}
|
||||
category_applied_input_types = {
|
||||
"sexual": ["text", "image"],
|
||||
"hate": ["text"],
|
||||
"harassment": ["text"],
|
||||
"self-harm": ["text", "image"],
|
||||
"sexual/minors": ["text"],
|
||||
"hate/threatening": ["text"],
|
||||
"violence/graphic": ["text", "image"],
|
||||
"self-harm/intent": ["text", "image"],
|
||||
"self-harm/instructions": ["text", "image"],
|
||||
"harassment/threatening": ["text"],
|
||||
"violence": ["text", "image"],
|
||||
"illicit": ["text"],
|
||||
"illicit/violent": ["text"],
|
||||
}
|
||||
result.append(
|
||||
Moderation(
|
||||
flagged=True,
|
||||
categories=Categories(**moderation_categories),
|
||||
category_scores=CategoryScores(**moderation_categories_scores),
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
)
|
||||
)
|
||||
else:
|
||||
moderation_categories = {
|
||||
"harassment": False,
|
||||
"harassment/threatening": False,
|
||||
"hate": False,
|
||||
"hate/threatening": False,
|
||||
"self-harm": False,
|
||||
"self-harm/instructions": False,
|
||||
"self-harm/intent": False,
|
||||
"sexual": False,
|
||||
"sexual/minors": False,
|
||||
"violence": False,
|
||||
"violence/graphic": False,
|
||||
"illicit": False,
|
||||
"illicit/violent": False,
|
||||
}
|
||||
moderation_categories_scores = {
|
||||
"harassment": 0.0,
|
||||
"harassment/threatening": 0.0,
|
||||
"hate": 0.0,
|
||||
"hate/threatening": 0.0,
|
||||
"self-harm": 0.0,
|
||||
"self-harm/instructions": 0.0,
|
||||
"self-harm/intent": 0.0,
|
||||
"sexual": 0.0,
|
||||
"sexual/minors": 0.0,
|
||||
"violence": 0.0,
|
||||
"violence/graphic": 0.0,
|
||||
"illicit": 0.0,
|
||||
"illicit/violent": 0.0,
|
||||
}
|
||||
category_applied_input_types = {
|
||||
"sexual": ["text", "image"],
|
||||
"hate": ["text"],
|
||||
"harassment": ["text"],
|
||||
"self-harm": ["text", "image"],
|
||||
"sexual/minors": ["text"],
|
||||
"hate/threatening": ["text"],
|
||||
"violence/graphic": ["text", "image"],
|
||||
"self-harm/intent": ["text", "image"],
|
||||
"self-harm/instructions": ["text", "image"],
|
||||
"harassment/threatening": ["text"],
|
||||
"violence": ["text", "image"],
|
||||
"illicit": ["text"],
|
||||
"illicit/violent": ["text"],
|
||||
}
|
||||
result.append(
|
||||
Moderation(
|
||||
flagged=False,
|
||||
categories=Categories(**moderation_categories),
|
||||
category_scores=CategoryScores(**moderation_categories_scores),
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
)
|
||||
)
|
||||
|
||||
return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result)
|
||||
@ -1,22 +0,0 @@
|
||||
from time import time
|
||||
|
||||
from openai.types.model import Model
|
||||
|
||||
|
||||
class MockModelClass:
|
||||
"""
|
||||
mock class for openai.models.Models
|
||||
"""
|
||||
|
||||
def list(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> list[Model]:
|
||||
return [
|
||||
Model(
|
||||
id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ",
|
||||
created=int(time()),
|
||||
object="model",
|
||||
owned_by="organization:org-123",
|
||||
)
|
||||
]
|
||||
@ -1,29 +0,0 @@
|
||||
import re
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from openai._types import NOT_GIVEN, FileTypes, NotGiven
|
||||
from openai.resources.audio.transcriptions import Transcriptions
|
||||
from openai.types.audio.transcription import Transcription
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class MockSpeech2TextClass:
|
||||
def speech2text_create(
|
||||
self: Transcriptions,
|
||||
*,
|
||||
file: FileTypes,
|
||||
model: Union[str, Literal["whisper-1"]],
|
||||
language: str | NotGiven = NOT_GIVEN,
|
||||
prompt: str | NotGiven = NOT_GIVEN,
|
||||
response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
|
||||
temperature: float | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any,
|
||||
) -> Transcription:
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
|
||||
raise InvokeAuthorizationError("Invalid base url")
|
||||
|
||||
if len(self._client.api_key) < 18:
|
||||
raise InvokeAuthorizationError("Invalid API key")
|
||||
|
||||
return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
|
||||
@ -0,0 +1,44 @@
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from core.plugin.manager.model import PluginModelManager
|
||||
from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass
|
||||
|
||||
|
||||
def mock_plugin_daemon(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> Callable[[], None]:
|
||||
"""
|
||||
mock openai module
|
||||
|
||||
:param monkeypatch: pytest monkeypatch fixture
|
||||
:return: unpatch function
|
||||
"""
|
||||
|
||||
def unpatch() -> None:
|
||||
monkeypatch.undo()
|
||||
|
||||
monkeypatch.setattr(PluginModelManager, "invoke_llm", MockModelClass.invoke_llm)
|
||||
monkeypatch.setattr(PluginModelManager, "fetch_model_providers", MockModelClass.fetch_model_providers)
|
||||
monkeypatch.setattr(PluginModelManager, "get_model_schema", MockModelClass.get_model_schema)
|
||||
|
||||
return unpatch
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_model_mock(monkeypatch):
|
||||
if MOCK:
|
||||
unpatch = mock_plugin_daemon(monkeypatch)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
unpatch()
|
||||
249
api/tests/integration_tests/model_runtime/__mock/plugin_model.py
Normal file
249
api/tests/integration_tests/model_runtime/__mock/plugin_model.py
Normal file
@ -0,0 +1,249 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from collections.abc import Generator, Sequence
|
||||
from decimal import Decimal
|
||||
from json import dumps
|
||||
|
||||
# import monkeypatch
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
)
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.manager.model import PluginModelManager
|
||||
|
||||
|
||||
class MockModelClass(PluginModelManager):
|
||||
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
|
||||
"""
|
||||
Fetch model providers for the given tenant.
|
||||
"""
|
||||
return [
|
||||
PluginModelProviderEntity(
|
||||
id=uuid.uuid4().hex,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
provider="openai",
|
||||
tenant_id=tenant_id,
|
||||
plugin_unique_identifier="langgenius/openai/openai",
|
||||
plugin_id="langgenius/openai",
|
||||
declaration=ProviderEntity(
|
||||
provider="openai",
|
||||
label=I18nObject(
|
||||
en_US="OpenAI",
|
||||
zh_Hans="OpenAI",
|
||||
),
|
||||
description=I18nObject(
|
||||
en_US="OpenAI",
|
||||
zh_Hans="OpenAI",
|
||||
),
|
||||
icon_small=I18nObject(
|
||||
en_US="https://example.com/icon_small.png",
|
||||
zh_Hans="https://example.com/icon_small.png",
|
||||
),
|
||||
icon_large=I18nObject(
|
||||
en_US="https://example.com/icon_large.png",
|
||||
zh_Hans="https://example.com/icon_large.png",
|
||||
),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
models=[
|
||||
AIModelEntity(
|
||||
model="gpt-3.5-turbo",
|
||||
label=I18nObject(
|
||||
en_US="gpt-3.5-turbo",
|
||||
zh_Hans="gpt-3.5-turbo",
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL],
|
||||
),
|
||||
AIModelEntity(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
label=I18nObject(
|
||||
en_US="gpt-3.5-turbo-instruct",
|
||||
zh_Hans="gpt-3.5-turbo-instruct",
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: LLMMode.COMPLETION,
|
||||
},
|
||||
features=[],
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
def get_model_schema(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US="OpenAI",
|
||||
zh_Hans="OpenAI",
|
||||
),
|
||||
model_type=ModelType(model_type),
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL] if model == "gpt-3.5-turbo" else [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_function_call(
|
||||
tools: Optional[list[PromptMessageTool]],
|
||||
) -> Optional[AssistantPromptMessage.ToolCall]:
|
||||
if not tools or len(tools) == 0:
|
||||
return None
|
||||
function: PromptMessageTool = tools[0]
|
||||
function_name = function.name
|
||||
function_parameters = function.parameters
|
||||
function_parameters_type = function_parameters["type"]
|
||||
if function_parameters_type != "object":
|
||||
return None
|
||||
function_parameters_properties = function_parameters["properties"]
|
||||
function_parameters_required = function_parameters["required"]
|
||||
parameters = {}
|
||||
for parameter_name, parameter in function_parameters_properties.items():
|
||||
if parameter_name not in function_parameters_required:
|
||||
continue
|
||||
parameter_type = parameter["type"]
|
||||
if parameter_type == "string":
|
||||
if "enum" in parameter:
|
||||
if len(parameter["enum"]) == 0:
|
||||
continue
|
||||
parameters[parameter_name] = parameter["enum"][0]
|
||||
else:
|
||||
parameters[parameter_name] = "kawaii"
|
||||
elif parameter_type == "integer":
|
||||
parameters[parameter_name] = 114514
|
||||
elif parameter_type == "number":
|
||||
parameters[parameter_name] = 1919810.0
|
||||
elif parameter_type == "boolean":
|
||||
parameters[parameter_name] = True
|
||||
|
||||
return AssistantPromptMessage.ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=function_name,
|
||||
arguments=dumps(parameters),
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_chat_create_sync(
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> LLMResult:
|
||||
tool_call = MockModelClass.generate_function_call(tools=tools)
|
||||
|
||||
return LLMResult(
|
||||
id=str(uuid.uuid4()),
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content="elaina", tool_calls=[tool_call] if tool_call else []),
|
||||
usage=LLMUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=1,
|
||||
total_tokens=3,
|
||||
prompt_unit_price=Decimal(0.0001),
|
||||
completion_unit_price=Decimal(0.0002),
|
||||
prompt_price_unit=Decimal(1),
|
||||
prompt_price=Decimal(0.0001),
|
||||
completion_price_unit=Decimal(1),
|
||||
completion_price=Decimal(0.0002),
|
||||
total_price=Decimal(0.0003),
|
||||
currency="USD",
|
||||
latency=0.001,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_chat_create_stream(
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
tool_call = MockModelClass.generate_function_call(tools=tools)
|
||||
|
||||
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
|
||||
for i in range(0, len(full_text) + 1):
|
||||
if i == len(full_text):
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content="",
|
||||
tool_calls=[tool_call] if tool_call else [],
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=full_text[i],
|
||||
tool_calls=[tool_call] if tool_call else [],
|
||||
),
|
||||
usage=LLMUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=17,
|
||||
total_tokens=19,
|
||||
prompt_unit_price=Decimal(0.0001),
|
||||
completion_unit_price=Decimal(0.0002),
|
||||
prompt_price_unit=Decimal(1),
|
||||
prompt_price=Decimal(0.0001),
|
||||
completion_price_unit=Decimal(1),
|
||||
completion_price=Decimal(0.0002),
|
||||
total_price=Decimal(0.0003),
|
||||
currency="USD",
|
||||
latency=0.001,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_llm(
|
||||
self: PluginModelManager,
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
):
|
||||
return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools)
|
||||
@ -1,169 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from requests import Response
|
||||
from requests.sessions import Session
|
||||
from xinference_client.client.restful.restful_client import ( # type: ignore
|
||||
Client,
|
||||
RESTfulChatModelHandle,
|
||||
RESTfulEmbeddingModelHandle,
|
||||
RESTfulGenerateModelHandle,
|
||||
RESTfulRerankModelHandle,
|
||||
)
|
||||
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage # type: ignore
|
||||
|
||||
|
||||
class MockXinferenceClass:
|
||||
def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
|
||||
if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
if model_uid == "generate":
|
||||
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
|
||||
if model_uid == "chat":
|
||||
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
|
||||
if model_uid == "embedding":
|
||||
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
|
||||
if model_uid == "rerank":
|
||||
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
def get(self: Session, url: str, **kwargs):
|
||||
response = Response()
|
||||
if "v1/models/" in url:
|
||||
# get model uid
|
||||
model_uid = url.split("/")[-1] or ""
|
||||
if not re.match(
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
|
||||
) and model_uid not in {"generate", "chat", "embedding", "rerank"}:
|
||||
response.status_code = 404
|
||||
response._content = b"{}"
|
||||
return response
|
||||
|
||||
# check if url is valid
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
|
||||
response.status_code = 404
|
||||
response._content = b"{}"
|
||||
return response
|
||||
|
||||
if model_uid in {"generate", "chat"}:
|
||||
response.status_code = 200
|
||||
response._content = b"""{
|
||||
"model_type": "LLM",
|
||||
"address": "127.0.0.1:43877",
|
||||
"accelerators": [
|
||||
"0",
|
||||
"1"
|
||||
],
|
||||
"model_name": "chatglm3-6b",
|
||||
"model_lang": [
|
||||
"en"
|
||||
],
|
||||
"model_ability": [
|
||||
"generate",
|
||||
"chat"
|
||||
],
|
||||
"model_description": "latest chatglm3",
|
||||
"model_format": "pytorch",
|
||||
"model_size_in_billions": 7,
|
||||
"quantization": "none",
|
||||
"model_hub": "huggingface",
|
||||
"revision": null,
|
||||
"context_length": 2048,
|
||||
"replica": 1
|
||||
}"""
|
||||
return response
|
||||
|
||||
elif model_uid == "embedding":
|
||||
response.status_code = 200
|
||||
response._content = b"""{
|
||||
"model_type": "embedding",
|
||||
"address": "127.0.0.1:43877",
|
||||
"accelerators": [
|
||||
"0",
|
||||
"1"
|
||||
],
|
||||
"model_name": "bge",
|
||||
"model_lang": [
|
||||
"en"
|
||||
],
|
||||
"revision": null,
|
||||
"max_tokens": 512
|
||||
}"""
|
||||
return response
|
||||
|
||||
elif "v1/cluster/auth" in url:
|
||||
response.status_code = 200
|
||||
response._content = b"""{
|
||||
"auth": true
|
||||
}"""
|
||||
return response
|
||||
|
||||
def _check_cluster_authenticated(self):
|
||||
self._cluster_authed = True
|
||||
|
||||
def rerank(
|
||||
self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
|
||||
) -> dict:
|
||||
# check if self._model_uid is a valid uuid
|
||||
if (
|
||||
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
|
||||
and self._model_uid != "rerank"
|
||||
):
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
if top_n is None:
|
||||
top_n = 1
|
||||
|
||||
return {
|
||||
"results": [
|
||||
{"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
|
||||
]
|
||||
}
|
||||
|
||||
def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
|
||||
# check if self._model_uid is a valid uuid
|
||||
if (
|
||||
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
|
||||
and self._model_uid != "embedding"
|
||||
):
|
||||
raise RuntimeError("404 Not Found")
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
ipt_len = len(input)
|
||||
|
||||
embedding = Embedding(
|
||||
object="list",
|
||||
model=self._model_uid,
|
||||
data=[
|
||||
EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
|
||||
for i in range(ipt_len)
|
||||
],
|
||||
usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
|
||||
monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
|
||||
monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
|
||||
monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
|
||||
monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
@ -1,92 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeLanguageModel
|
||||
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
|
||||
def test_validate_credentials(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(
|
||||
model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
|
||||
def test_invoke_model(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="claude-instant-1.2",
|
||||
credentials={
|
||||
"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"),
|
||||
"anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
|
||||
def test_invoke_stream_model(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="claude-instant-1.2",
|
||||
credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="claude-instant-1.2",
|
||||
credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 18
|
||||
@ -1,17 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider
|
||||
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
|
||||
def test_validate_provider_credentials(setup_anthropic_mock):
|
||||
provider = AnthropicProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={})
|
||||
|
||||
provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")})
|
||||
Binary file not shown.
@ -1,109 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
|
||||
def test_validate_credentials(setup_azure_ai_studio_mock):
|
||||
model = AzureAIStudioLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="gpt-35-turbo",
|
||||
credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="gpt-35-turbo",
|
||||
credentials={
|
||||
"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_model(setup_azure_ai_studio_mock):
|
||||
model = AzureAIStudioLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="gpt-35-turbo",
|
||||
credentials={
|
||||
"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_stream_model(setup_azure_ai_studio_mock):
|
||||
model = AzureAIStudioLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="gpt-35-turbo",
|
||||
credentials={
|
||||
"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
assert chunk.delta.usage.completion_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = AzureAIStudioLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="gpt-35-turbo",
|
||||
credentials={
|
||||
"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 21
|
||||
@ -1,17 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.azure_ai_studio.azure_ai_studio import AzureAIStudioProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = AzureAIStudioProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={})
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}
|
||||
)
|
||||
@ -1,42 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = AzureRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="azure-ai-studio-rerank-v1",
|
||||
credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = AzureRerankModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="azure-ai-studio-rerank-v1",
|
||||
credentials={
|
||||
"api_key": os.getenv("AZURE_AI_STUDIO_JWT_TOKEN"),
|
||||
"api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
|
||||
},
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 1
|
||||
assert result.docs[0].index == 1
|
||||
assert result.docs[0].score >= 0.8
|
||||
File diff suppressed because one or more lines are too long
@ -1,62 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.azure_openai.text_embedding.text_embedding import AzureOpenAITextEmbeddingModel
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
|
||||
def test_validate_credentials(setup_openai_mock):
|
||||
model = AzureOpenAITextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="embedding",
|
||||
credentials={
|
||||
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
|
||||
"openai_api_key": "invalid_key",
|
||||
"base_model_name": "text-embedding-ada-002",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="embedding",
|
||||
credentials={
|
||||
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
|
||||
"openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
|
||||
"base_model_name": "text-embedding-ada-002",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
|
||||
def test_invoke_model(setup_openai_mock):
|
||||
model = AzureOpenAITextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="embedding",
|
||||
credentials={
|
||||
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
|
||||
"openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
|
||||
"base_model_name": "text-embedding-ada-002",
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = AzureOpenAITextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
@ -1,172 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLanguageModel
|
||||
|
||||
|
||||
def test_predefined_models():
|
||||
model = BaichuanLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
sleep(3)
|
||||
model = BaichuanLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="baichuan2-turbo",
|
||||
credentials={
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
sleep(3)
|
||||
model = BaichuanLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="baichuan2-turbo",
|
||||
credentials={
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
model_parameters={
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
"top_k": 1,
|
||||
},
|
||||
stop=["you"],
|
||||
user="abc-123",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_invoke_model_with_system_message():
|
||||
sleep(3)
|
||||
model = BaichuanLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="baichuan2-turbo",
|
||||
credentials={
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(content="请记住你是Kasumi。"),
|
||||
UserPromptMessage(content="现在告诉我你是谁?"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
"top_k": 1,
|
||||
},
|
||||
stop=["you"],
|
||||
user="abc-123",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
|
||||
sleep(3)
|
||||
model = BaichuanLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="baichuan2-turbo",
|
||||
credentials={
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
model_parameters={
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
"top_k": 1,
|
||||
},
|
||||
stop=["you"],
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_invoke_with_search():
|
||||
sleep(3)
|
||||
model = BaichuanLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="baichuan2-turbo",
|
||||
credentials={
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
|
||||
model_parameters={
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
"top_k": 1,
|
||||
"with_search_enhance": True,
|
||||
},
|
||||
stop=["you"],
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
total_message = ""
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
|
||||
total_message += chunk.delta.message.content
|
||||
|
||||
assert "不" not in total_message
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
sleep(3)
|
||||
model = BaichuanLanguageModel()
|
||||
|
||||
response = model.get_num_tokens(
|
||||
model="baichuan2-turbo",
|
||||
credentials={
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
tools=[],
|
||||
)
|
||||
|
||||
assert isinstance(response, int)
|
||||
assert response == 9
|
||||
@ -1,15 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.baichuan.baichuan import BaichuanProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = BaichuanProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
|
||||
|
||||
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")})
|
||||
@ -1,87 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.baichuan.text_embedding.text_embedding import BaichuanTextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = BaichuanTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(
|
||||
model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = BaichuanTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="baichuan-text-embedding",
|
||||
credentials={
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 6
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = BaichuanTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="baichuan-text-embedding",
|
||||
credentials={
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
|
||||
|
||||
def test_max_chunks():
|
||||
model = BaichuanTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="baichuan-text-embedding",
|
||||
credentials={
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
"hello",
|
||||
"world",
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 22
|
||||
@ -1,103 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.bedrock.llm.llm import BedrockLargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = BedrockLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(
|
||||
model="meta.llama2-13b-chat-v1",
|
||||
credentials={
|
||||
"aws_region": os.getenv("AWS_REGION"),
|
||||
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
|
||||
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = BedrockLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="meta.llama2-13b-chat-v1",
|
||||
credentials={
|
||||
"aws_region": os.getenv("AWS_REGION"),
|
||||
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
|
||||
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = BedrockLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="meta.llama2-13b-chat-v1",
|
||||
credentials={
|
||||
"aws_region": os.getenv("AWS_REGION"),
|
||||
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
|
||||
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100},
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = BedrockLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="meta.llama2-13b-chat-v1",
|
||||
credentials={
|
||||
"aws_region": os.getenv("AWS_REGION"),
|
||||
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
|
||||
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
},
|
||||
messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 18
|
||||
@ -1,21 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = BedrockProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={})
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
"aws_region": os.getenv("AWS_REGION"),
|
||||
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
|
||||
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
}
|
||||
)
|
||||
@ -1,229 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
def test_predefined_models():
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_validate_credentials_for_chat_model(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"})
|
||||
|
||||
model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_model(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="chatglm2-6b",
|
||||
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
stop=["you"],
|
||||
user="abc-123",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_stream_model(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="chatglm2-6b",
|
||||
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
stop=["you"],
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_stream_model_with_functions(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="chatglm3-6b",
|
||||
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。"
|
||||
),
|
||||
UserPromptMessage(content="波士顿天气如何?"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
stop=["you"],
|
||||
user="abc-123",
|
||||
stream=True,
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather in a given location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
call: LLMResultChunk = None
|
||||
chunks = []
|
||||
|
||||
for chunk in response:
|
||||
chunks.append(chunk)
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0:
|
||||
call = chunk
|
||||
break
|
||||
|
||||
assert call is not None
|
||||
assert call.delta.message.tool_calls[0].function.name == "get_current_weather"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_model_with_functions(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="chatglm3-6b",
|
||||
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
|
||||
prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")],
|
||||
model_parameters={
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
stop=["you"],
|
||||
user="abc-123",
|
||||
stream=False,
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather in a given location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
assert response.message.tool_calls[0].function.name == "get_current_weather"
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="chatglm2-6b",
|
||||
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather in a given location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 77
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="chatglm2-6b",
|
||||
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 21
|
||||
@ -1,17 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_validate_provider_credentials(setup_openai_mock):
|
||||
provider = ChatGLMProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={"api_base": "hahahaha"})
|
||||
|
||||
provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})
|
||||
@ -1,191 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
|
||||
|
||||
|
||||
def test_validate_credentials_for_completion_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
|
||||
|
||||
|
||||
def test_invoke_completion_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
credentials = {"api_key": os.environ.get("COHERE_API_KEY")}
|
||||
|
||||
result = model.invoke(
|
||||
model="command-light",
|
||||
credentials=credentials,
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 1},
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1
|
||||
|
||||
|
||||
def test_invoke_stream_completion_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="command-light",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_invoke_chat_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="command-light-chat",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.0,
|
||||
"p": 0.99,
|
||||
"presence_penalty": 0.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"max_tokens": 10,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_chat_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="command-light-chat",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
assert chunk.delta.usage.completion_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="command-light",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert num_tokens == 3
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="command-light-chat",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 15
|
||||
|
||||
|
||||
def test_fine_tuned_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
# test invoke
|
||||
result = model.invoke(
|
||||
model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
|
||||
|
||||
def test_fine_tuned_chat_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
# test invoke
|
||||
result = model.invoke(
|
||||
model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
@ -1,15 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.cohere.cohere import CohereProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = CohereProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={})
|
||||
|
||||
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")})
|
||||
@ -1,40 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = CohereRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = CohereRerankModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="rerank-english-v2.0",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
|
||||
"is the capital of the United States. It is a federal district. The President of the USA and many major "
|
||||
"national government offices are in the territory. This makes it the political center of the United "
|
||||
"States of America.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 1
|
||||
assert result.docs[0].index == 1
|
||||
assert result.docs[0].score >= 0.8
|
||||
@ -1,45 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = CohereTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(
|
||||
model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = CohereTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="embed-multilingual-v3.0",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
|
||||
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 4
|
||||
assert result.usage.total_tokens == 811
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = CohereTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="embed-multilingual-v3.0",
|
||||
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
assert num_tokens == 3
|
||||
@ -1,186 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
def test_predefined_models():
|
||||
model = FireworksLargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_validate_credentials_for_chat_model(setup_openai_mock):
|
||||
model = FireworksLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
# model name to gpt-3.5-turbo because of mocking
|
||||
model.validate_credentials(model="gpt-3.5-turbo", credentials={"fireworks_api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(
|
||||
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model(setup_openai_mock):
|
||||
model = FireworksLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 0.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"max_tokens": 10,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model_with_tools(setup_openai_mock):
|
||||
model = FireworksLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(
|
||||
content="what's the weather today in London?",
|
||||
),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
PromptMessageTool(
|
||||
name="get_stock_price",
|
||||
description="Get the current stock price",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
|
||||
"required": ["symbol"],
|
||||
},
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert isinstance(result.message, AssistantPromptMessage)
|
||||
assert len(result.message.tool_calls) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_stream_chat_model(setup_openai_mock):
|
||||
model = FireworksLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
assert chunk.delta.usage.completion_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = FireworksLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert num_tokens == 10
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 77
|
||||
@ -1,17 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.fireworks.fireworks import FireworksProvider
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_validate_provider_credentials(setup_openai_mock):
|
||||
provider = FireworksProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={})
|
||||
|
||||
provider.validate_provider_credentials(credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")})
|
||||
@ -1,54 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.fireworks.text_embedding.text_embedding import FireworksTextEmbeddingModel
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
|
||||
def test_validate_credentials(setup_openai_mock):
|
||||
model = FireworksTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": "invalid_key"}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
|
||||
def test_invoke_model(setup_openai_mock):
|
||||
model = FireworksTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="nomic-ai/nomic-embed-text-v1.5",
|
||||
credentials={
|
||||
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 4
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = FireworksTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="nomic-ai/nomic-embed-text-v1.5",
|
||||
credentials={
|
||||
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
@ -1,33 +0,0 @@
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.fishaudio.fishaudio import FishAudioProvider
|
||||
from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_fishaudio_mock", [["list-models"]], indirect=True)
|
||||
def test_validate_provider_credentials(setup_fishaudio_mock):
|
||||
print("-----", httpx.get)
|
||||
provider = FishAudioProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
"api_key": "bad_api_key",
|
||||
"api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
|
||||
"use_public_models": "false",
|
||||
"latency": "normal",
|
||||
}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
"api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
|
||||
"api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
|
||||
"use_public_models": "false",
|
||||
"latency": "normal",
|
||||
}
|
||||
)
|
||||
@ -1,32 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.model_providers.fishaudio.tts.tts import (
|
||||
FishAudioText2SpeechModel,
|
||||
)
|
||||
from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_fishaudio_mock", [["tts"]], indirect=True)
|
||||
def test_invoke_model(setup_fishaudio_mock):
|
||||
model = FishAudioText2SpeechModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="tts-default",
|
||||
tenant_id="test",
|
||||
credentials={
|
||||
"api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
|
||||
"api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
|
||||
"use_public_models": "false",
|
||||
"latency": "normal",
|
||||
},
|
||||
content_text="Hello, world!",
|
||||
voice="03397b4c4be74759b72533b663fbd001",
|
||||
)
|
||||
|
||||
content = b""
|
||||
for chunk in result:
|
||||
content += chunk
|
||||
|
||||
assert content != b""
|
||||
@ -1,132 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gitee_ai.llm.llm import GiteeAILargeLanguageModel
|
||||
|
||||
|
||||
def test_predefined_models():
|
||||
model = GiteeAILargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
model = GiteeAILargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
# model name to gpt-3.5-turbo because of mocking
|
||||
model.validate_credentials(model="gpt-3.5-turbo", credentials={"api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(
|
||||
model="Qwen2-7B-Instruct",
|
||||
credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_chat_model():
|
||||
model = GiteeAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="Qwen2-7B-Instruct",
|
||||
credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 0.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"max_tokens": 10,
|
||||
"stream": False,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_chat_model():
|
||||
model = GiteeAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="Qwen2-7B-Instruct",
|
||||
credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100, "stream": False},
|
||||
stream=True,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = GiteeAILargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="Qwen2-7B-Instruct",
|
||||
credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert num_tokens == 10
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="Qwen2-7B-Instruct",
|
||||
credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 77
|
||||
@ -1,15 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gitee_ai.gitee_ai import GiteeAIProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = GiteeAIProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={"api_key": "invalid_key"})
|
||||
|
||||
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")})
|
||||
@ -1,47 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gitee_ai.rerank.rerank import GiteeAIRerankModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = GiteeAIRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={"api_key": "invalid_key"},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"api_key": os.environ.get("GITEE_AI_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GiteeAIRerankModel()
|
||||
result = model.invoke(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"api_key": os.environ.get("GITEE_AI_API_KEY"),
|
||||
},
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
top_n=1,
|
||||
score_threshold=0.01,
|
||||
)
|
||||
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 1
|
||||
assert result.docs[0].score >= 0.01
|
||||
@ -1,45 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gitee_ai.speech2text.speech2text import GiteeAISpeech2TextModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = GiteeAISpeech2TextModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="whisper-base",
|
||||
credentials={"api_key": "invalid_key"},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="whisper-base",
|
||||
credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GiteeAISpeech2TextModel()
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Get assets directory
|
||||
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
|
||||
|
||||
# Construct the path to the audio file
|
||||
audio_file_path = os.path.join(assets_dir, "audio.mp3")
|
||||
|
||||
# Open the file and get the file object
|
||||
with open(audio_file_path, "rb") as audio_file:
|
||||
file = audio_file
|
||||
|
||||
result = model.invoke(
|
||||
model="whisper-base", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, file=file
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "1 2 3 4 5 6 7 8 9 10"
|
||||
@ -1,46 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gitee_ai.text_embedding.text_embedding import GiteeAIEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = GiteeAIEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": "invalid_key"})
|
||||
|
||||
model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")})
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GiteeAIEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="bge-large-zh-v1.5",
|
||||
credentials={
|
||||
"api_key": os.environ.get("GITEE_AI_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="user",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = GiteeAIEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="bge-large-zh-v1.5",
|
||||
credentials={
|
||||
"api_key": os.environ.get("GITEE_AI_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
@ -1,23 +0,0 @@
|
||||
import os
|
||||
|
||||
from core.model_runtime.model_providers.gitee_ai.tts.tts import GiteeAIText2SpeechModel
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GiteeAIText2SpeechModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="speecht5_tts",
|
||||
tenant_id="test",
|
||||
credentials={
|
||||
"api_key": os.environ.get("GITEE_AI_API_KEY"),
|
||||
},
|
||||
content_text="Hello, world!",
|
||||
voice="",
|
||||
)
|
||||
|
||||
content = b""
|
||||
for chunk in result:
|
||||
content += chunk
|
||||
|
||||
assert content != b""
|
||||
File diff suppressed because one or more lines are too long
@ -1,17 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.google.google import GoogleProvider
|
||||
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
|
||||
def test_validate_provider_credentials(setup_google_mock):
|
||||
provider = GoogleProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={})
|
||||
|
||||
provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")})
|
||||
@ -1,49 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import (
|
||||
GPUStackTextEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = GPUStackTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="bge-m3",
|
||||
credentials={
|
||||
"endpoint_url": "invalid_url",
|
||||
"api_key": "invalid_api_key",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="bge-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GPUStackTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="bge-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"context_size": 8192,
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 7
|
||||
@ -1,162 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": "invalid_url",
|
||||
"api_key": "invalid_api_key",
|
||||
"mode": "chat",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_completion_model():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "completion",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="ping")],
|
||||
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
|
||||
stop=[],
|
||||
user="abc-123",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_invoke_chat_model():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="ping")],
|
||||
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
|
||||
stop=[],
|
||||
user="abc-123",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_invoke_stream_chat_model():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
|
||||
stop=["you"],
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="????",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather in a given location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 80
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="????",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 10
|
||||
@ -1,107 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gpustack.rerank.rerank import (
|
||||
GPUStackRerankModel,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_for_rerank_model():
|
||||
model = GPUStackRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": "invalid_url",
|
||||
"api_key": "invalid_api_key",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_rerank_model():
|
||||
model = GPUStackRerankModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
query="Organic skincare products for sensitive skin",
|
||||
docs=[
|
||||
"Eco-friendly kitchenware for modern homes",
|
||||
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||
"Organic cotton baby clothes for sensitive skin",
|
||||
"Natural organic skincare range for sensitive skin",
|
||||
"Tech gadgets for smart homes: 2024 edition",
|
||||
"Sustainable gardening tools and compost solutions",
|
||||
"Sensitive skin-friendly facial cleansers and toners",
|
||||
"Organic food wraps and storage solutions",
|
||||
"Yoga mats made from recycled materials",
|
||||
],
|
||||
top_n=3,
|
||||
score_threshold=-0.75,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, RerankResult)
|
||||
assert len(response.docs) == 3
|
||||
|
||||
|
||||
def test__invoke():
|
||||
model = GPUStackRerankModel()
|
||||
|
||||
# Test case 1: Empty docs
|
||||
result = model._invoke(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
query="Organic skincare products for sensitive skin",
|
||||
docs=[],
|
||||
top_n=3,
|
||||
score_threshold=0.75,
|
||||
user="abc-123",
|
||||
)
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 0
|
||||
|
||||
# Test case 2: Expected docs
|
||||
result = model._invoke(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
query="Organic skincare products for sensitive skin",
|
||||
docs=[
|
||||
"Eco-friendly kitchenware for modern homes",
|
||||
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||
"Organic cotton baby clothes for sensitive skin",
|
||||
"Natural organic skincare range for sensitive skin",
|
||||
"Tech gadgets for smart homes: 2024 edition",
|
||||
"Sustainable gardening tools and compost solutions",
|
||||
"Sensitive skin-friendly facial cleansers and toners",
|
||||
"Organic food wraps and storage solutions",
|
||||
"Yoga mats made from recycled materials",
|
||||
],
|
||||
top_n=3,
|
||||
score_threshold=-0.75,
|
||||
user="abc-123",
|
||||
)
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 3
|
||||
assert all(isinstance(doc, RerankDocument) for doc in result.docs)
|
||||
@ -1,55 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = GPUStackSpeech2TextModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="faster-whisper-medium",
|
||||
credentials={
|
||||
"endpoint_url": "invalid_url",
|
||||
"api_key": "invalid_api_key",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="faster-whisper-medium",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GPUStackSpeech2TextModel()
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Get assets directory
|
||||
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
|
||||
|
||||
# Construct the path to the audio file
|
||||
audio_file_path = os.path.join(assets_dir, "audio.mp3")
|
||||
|
||||
file = Path(audio_file_path).read_bytes()
|
||||
|
||||
result = model.invoke(
|
||||
model="faster-whisper-medium",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
file=file,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
|
||||
@ -1,24 +0,0 @@
|
||||
import os
|
||||
|
||||
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GPUStackText2SpeechModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="cosyvoice-300m-sft",
|
||||
tenant_id="test",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
content_text="Hello world",
|
||||
voice="Chinese Female",
|
||||
)
|
||||
|
||||
content = b""
|
||||
for chunk in result:
|
||||
content += chunk
|
||||
|
||||
assert content != b""
|
||||
@ -1,278 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
|
||||
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="HuggingFaceH4/zephyr-7b-beta",
|
||||
credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
|
||||
)
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="fake-model",
|
||||
credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="HuggingFaceH4/zephyr-7b-beta",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "hosted_inference_api",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
|
||||
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="HuggingFaceH4/zephyr-7b-beta",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "hosted_inference_api",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Who are you?")],
|
||||
model_parameters={
|
||||
"temperature": 1.0,
|
||||
"top_k": 2,
|
||||
"top_p": 0.5,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
|
||||
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="HuggingFaceH4/zephyr-7b-beta",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "hosted_inference_api",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Who are you?")],
|
||||
model_parameters={
|
||||
"temperature": 1.0,
|
||||
"top_k": 2,
|
||||
"top_p": 0.5,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
|
||||
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="openchat/openchat_3.5",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": "invalid_key",
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
|
||||
"task_type": "text-generation",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="openchat/openchat_3.5",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
|
||||
"task_type": "text-generation",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
|
||||
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="openchat/openchat_3.5",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
|
||||
"task_type": "text-generation",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Who are you?")],
|
||||
model_parameters={
|
||||
"temperature": 1.0,
|
||||
"top_k": 2,
|
||||
"top_p": 0.5,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
|
||||
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="openchat/openchat_3.5",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
|
||||
"task_type": "text-generation",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Who are you?")],
|
||||
model_parameters={
|
||||
"temperature": 1.0,
|
||||
"top_k": 2,
|
||||
"top_p": 0.5,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
|
||||
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="google/mt5-base",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": "invalid_key",
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
|
||||
"task_type": "text2text-generation",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="google/mt5-base",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
|
||||
"task_type": "text2text-generation",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
|
||||
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="google/mt5-base",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
|
||||
"task_type": "text2text-generation",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Who are you?")],
|
||||
model_parameters={
|
||||
"temperature": 1.0,
|
||||
"top_k": 2,
|
||||
"top_p": 0.5,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
|
||||
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="google/mt5-base",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
|
||||
"task_type": "text2text-generation",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Who are you?")],
|
||||
model_parameters={
|
||||
"temperature": 1.0,
|
||||
"top_k": 2,
|
||||
"top_p": 0.5,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="google/mt5-base",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
|
||||
"task_type": "text2text-generation",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert num_tokens == 7
|
||||
@ -1,112 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import (
|
||||
HuggingfaceHubTextEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
def test_hosted_inference_api_validate_credentials():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="facebook/bart-base",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "hosted_inference_api",
|
||||
"huggingfacehub_api_token": "invalid_key",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="facebook/bart-base",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "hosted_inference_api",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_hosted_inference_api_invoke_model():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="facebook/bart-base",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "hosted_inference_api",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_inference_endpoints_validate_credentials():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="all-MiniLM-L6-v2",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": "invalid_key",
|
||||
"huggingface_namespace": "Dify-AI",
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
|
||||
"task_type": "feature-extraction",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="all-MiniLM-L6-v2",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingface_namespace": "Dify-AI",
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
|
||||
"task_type": "feature-extraction",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_inference_endpoints_invoke_model():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="all-MiniLM-L6-v2",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingface_namespace": "Dify-AI",
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
|
||||
"task_type": "feature-extraction",
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="all-MiniLM-L6-v2",
|
||||
credentials={
|
||||
"huggingfacehub_api_type": "inference_endpoints",
|
||||
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
|
||||
"huggingface_namespace": "Dify-AI",
|
||||
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
|
||||
"task_type": "feature-extraction",
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
@ -1,73 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
|
||||
HuggingfaceTeiTextEmbeddingModel,
|
||||
TeiHelper,
|
||||
)
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
|
||||
monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
|
||||
monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
|
||||
monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
|
||||
def test_validate_credentials(setup_tei_mock):
|
||||
model = HuggingfaceTeiTextEmbeddingModel()
|
||||
# model name is only used in mock
|
||||
model_name = "embedding"
|
||||
|
||||
if MOCK:
|
||||
# TEI Provider will check model type by API endpoint, at real server, the model type is correct.
|
||||
# So we dont need to check model type here. Only check in mock
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="reranker",
|
||||
credentials={
|
||||
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
||||
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model=model_name,
|
||||
credentials={
|
||||
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
||||
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
|
||||
def test_invoke_model(setup_tei_mock):
|
||||
model = HuggingfaceTeiTextEmbeddingModel()
|
||||
model_name = "embedding"
|
||||
|
||||
result = model.invoke(
|
||||
model=model_name,
|
||||
credentials={
|
||||
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
|
||||
"api_key": os.environ.get("TEI_API_KEY", ""),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens > 0
|
||||
@ -1,80 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
|
||||
HuggingfaceTeiRerankModel,
|
||||
)
|
||||
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
    """Patch TeiHelper with MockTEIClass while MOCK_SWITCH is enabled.

    Patches are reverted after the test so other tests see the real helper.
    """
    if MOCK:
        monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
        monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
        monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
        monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
    yield

    # undo the patches explicitly so later tests hit the real TeiHelper
    if MOCK:
        monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_tei_mock):
    """Reject a wrong model type (mock only) and accept valid reranker credentials."""
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = "reranker"

    if MOCK:
        # TEI Provider will check model type by API endpoint, at real server, the model type is correct.
        # So we dont need to check model type here. Only check in mock
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model="embedding",
                credentials={
                    # default to "" so a missing env var still yields a string,
                    # consistent with the TEI embedding tests in this suite
                    "server_url": os.environ.get("TEI_RERANK_SERVER_URL", ""),
                    "api_key": os.environ.get("TEI_API_KEY", ""),
                },
            )

    model.validate_credentials(
        model=model_name,
        credentials={
            "server_url": os.environ.get("TEI_RERANK_SERVER_URL", ""),
            "api_key": os.environ.get("TEI_API_KEY", ""),
        },
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
    """Rerank three docs; only the directly relevant one should pass the 0.8 threshold."""
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = "reranker"

    result = model.invoke(
        model=model_name,
        credentials={
            # default to "" so a missing env var still yields a string,
            # consistent with the TEI embedding tests in this suite
            "server_url": os.environ.get("TEI_RERANK_SERVER_URL", ""),
            "api_key": os.environ.get("TEI_API_KEY", ""),
        },
        query="Who is Kasumi?",
        docs=[
            'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
            "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
            "and she leads a team named PopiParty.",
        ],
        score_threshold=0.8,
    )

    # only the first document answers the query, so exactly one survives
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].index == 0
    assert result.docs[0].score >= 0.8
|
||||
@ -1,90 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.hunyuan.llm.llm import HunyuanLargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
    """Invalid secrets must raise; env-provided secrets must validate."""
    model = HunyuanLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
        )

    model.validate_credentials(
        model="hunyuan-standard",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
    )
|
||||
|
||||
|
||||
def test_invoke_model():
    """Non-streaming invocation returns a populated LLMResult."""
    model = HunyuanLargeLanguageModel()

    response = model.invoke(
        model="hunyuan-standard",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 10},
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
    """Streaming invocation yields well-formed chunks."""
    model = HunyuanLargeLanguageModel()

    response = model.invoke(
        model="hunyuan-standard",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # the final chunk (finish_reason set) may carry empty content
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
    """Token count for a system + user message pair is pinned at 14."""
    model = HunyuanLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="hunyuan-standard",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
    )

    # fixed expectation for this exact prompt pair; update if the tokenizer changes
    assert num_tokens == 14
|
||||
@ -1,20 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.hunyuan.hunyuan import HunyuanProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
    """Provider-level validation: bogus keys raise, env-supplied keys pass."""
    provider = HunyuanProvider()

    bogus_credentials = {
        "secret_id": "invalid_key",
        "secret_key": "invalid_key",
    }
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials=bogus_credentials)

    real_credentials = {
        "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
        "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
    }
    provider.validate_provider_credentials(credentials=real_credentials)
|
||||
@ -1,96 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.hunyuan.text_embedding.text_embedding import HunyuanTextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
    """Invalid secrets must raise; env-provided secrets must validate."""
    model = HunyuanTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
        )

    model.validate_credentials(
        model="hunyuan-embedding",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
    )
|
||||
|
||||
|
||||
def test_invoke_model():
    """Embed two texts and check the embedding count and pinned token usage."""
    model = HunyuanTextEmbeddingModel()

    result = model.invoke(
        model="hunyuan-embedding",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    # service-reported usage for these two inputs; update if the API changes
    assert result.usage.total_tokens == 6
|
||||
|
||||
|
||||
def test_get_num_tokens():
    """Local token estimate for two one-word texts is 2 (one per text)."""
    model = HunyuanTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model="hunyuan-embedding",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        texts=["hello", "world"],
    )

    assert num_tokens == 2
|
||||
|
||||
|
||||
def test_max_chunks():
    """Invoking with more texts than one batch allows still embeds every text."""
    model = HunyuanTextEmbeddingModel()

    # 22 inputs (alternating "hello"/"world") — enough to force chunked requests
    texts = ["hello", "world"] * 11

    result = model.invoke(
        model="hunyuan-embedding",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        texts=texts,
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 22
|
||||
@ -1,15 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.jina.jina import JinaProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
    """A junk API key must fail validation; the real JINA_API_KEY must pass."""
    provider = JinaProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})

    real_credentials = {"api_key": os.environ.get("JINA_API_KEY")}
    provider.validate_provider_credentials(credentials=real_credentials)
|
||||
@ -1,49 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.jina.text_embedding.text_embedding import JinaTextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
    """Invalid API key must raise; env-provided key must validate."""
    model = JinaTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")}
    )
|
||||
|
||||
|
||||
def test_invoke_model():
    """Embed two texts and check the embedding count and pinned token usage."""
    model = JinaTextEmbeddingModel()

    result = model.invoke(
        model="jina-embeddings-v2-base-en",
        credentials={
            "api_key": os.environ.get("JINA_API_KEY"),
        },
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    # service-reported usage for these two inputs; update if the API changes
    assert result.usage.total_tokens == 6
|
||||
|
||||
|
||||
def test_get_num_tokens():
    """Token estimate for two one-word texts is pinned at 6."""
    model = JinaTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model="jina-embeddings-v2-base-en",
        credentials={
            "api_key": os.environ.get("JINA_API_KEY"),
        },
        texts=["hello", "world"],
    )

    assert num_tokens == 6
|
||||
@ -1,4 +0,0 @@
|
||||
"""
LocalAI Embedding Interface is temporarily unavailable because
we have not yet found a way to test it.
"""
|
||||
@ -1,172 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.localai.llm.llm import LocalAILanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials_for_chat_model():
    """A bogus server URL must raise; the env-provided LocalAI URL must validate."""
    model = LocalAILanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="chinese-llama-2-7b",
            credentials={
                "server_url": "hahahaha",
                "completion_type": "completion",
            },
        )

    model.validate_credentials(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
    )
|
||||
|
||||
|
||||
def test_invoke_completion_model():
    """Completion-mode, non-streaming invocation returns a usable LLMResult."""
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_invoke_chat_model():
    """Chat-completion-mode, non-streaming invocation returns a usable LLMResult."""
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_invoke_stream_completion_model():
    """Completion-mode streaming invocation yields well-formed chunks."""
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # the final chunk (finish_reason set) may carry empty content
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_invoke_stream_chat_model():
    """Chat-completion-mode streaming invocation yields well-formed chunks."""
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # the final chunk (finish_reason set) may carry empty content
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
    """Token counting with and without tools yields the pinned values (77 / 10)."""
    model = LocalAILanguageModel()

    # with a tool definition, the tool schema contributes to the count
    num_tokens = model.get_num_tokens(
        model="????",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            )
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 77

    # without tools, only the single user message is counted
    num_tokens = model.get_num_tokens(
        model="????",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 10
|
||||
@ -1,96 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel
|
||||
|
||||
|
||||
def test_validate_credentials_for_chat_model():
    """A bogus server URL must raise; the env-provided LocalAI URL must validate."""
    model = LocalaiRerankModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="bge-reranker-v2-m3",
            credentials={
                "server_url": "hahahaha",
                "completion_type": "completion",
            },
        )

    model.validate_credentials(
        model="bge-reranker-base",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
    )
|
||||
|
||||
|
||||
def test_invoke_rerank_model():
    """Rerank nine documents; top_n=3 limits the returned docs to three."""
    model = LocalaiRerankModel()

    response = model.invoke(
        model="bge-reranker-base",
        credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
        query="Organic skincare products for sensitive skin",
        docs=[
            "Eco-friendly kitchenware for modern homes",
            "Biodegradable cleaning supplies for eco-conscious consumers",
            "Organic cotton baby clothes for sensitive skin",
            "Natural organic skincare range for sensitive skin",
            "Tech gadgets for smart homes: 2024 edition",
            "Sustainable gardening tools and compost solutions",
            "Sensitive skin-friendly facial cleansers and toners",
            "Organic food wraps and storage solutions",
            "Yoga mats made from recycled materials",
        ],
        top_n=3,
        score_threshold=0.75,
        user="abc-123",
    )

    assert isinstance(response, RerankResult)
    assert len(response.docs) == 3
|
||||
|
||||
|
||||
def test__invoke():
    """Exercise the private _invoke: empty docs short-circuit, valid docs rerank."""
    model = LocalaiRerankModel()

    # Test case 1: Empty docs — must return an empty RerankResult without calling the server
    result = model._invoke(
        model="bge-reranker-base",
        credentials={"server_url": "https://example.com", "api_key": "1234567890"},
        query="Organic skincare products for sensitive skin",
        docs=[],
        top_n=3,
        score_threshold=0.75,
        user="abc-123",
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 0

    # Test case 2: Valid invocation — top_n=3 caps the result set
    result = model._invoke(
        model="bge-reranker-base",
        credentials={"server_url": "https://example.com", "api_key": "1234567890"},
        query="Organic skincare products for sensitive skin",
        docs=[
            "Eco-friendly kitchenware for modern homes",
            "Biodegradable cleaning supplies for eco-conscious consumers",
            "Organic cotton baby clothes for sensitive skin",
            "Natural organic skincare range for sensitive skin",
            "Tech gadgets for smart homes: 2024 edition",
            "Sustainable gardening tools and compost solutions",
            "Sensitive skin-friendly facial cleansers and toners",
            "Organic food wraps and storage solutions",
            "Yoga mats made from recycled materials",
        ],
        top_n=3,
        score_threshold=0.75,
        user="abc-123",
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 3
    assert all(isinstance(doc, RerankDocument) for doc in result.docs)
|
||||
@ -1,42 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.localai.speech2text.speech2text import LocalAISpeech2text
|
||||
|
||||
|
||||
def test_validate_credentials():
    """An invalid server URL must raise; the env-provided LocalAI URL must validate."""
    model = LocalAISpeech2text()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"})

    model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")})
|
||||
|
||||
|
||||
def test_invoke_model():
    """Transcribe the bundled audio fixture and check the exact transcript."""
    model = LocalAISpeech2text()

    # Get the directory of the current file
    current_dir = os.path.dirname(os.path.abspath(__file__))

    # Get assets directory
    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")

    # Construct the path to the audio file
    audio_file_path = os.path.join(assets_dir, "audio.mp3")

    # Invoke INSIDE the `with` block: the original assigned the handle inside the
    # block and called invoke() after the file was already closed, so the model
    # received a closed file object.
    with open(audio_file_path, "rb") as audio_file:
        result = model.invoke(
            model="whisper-1",
            credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
            file=audio_file,
            user="abc-123",
        )

    assert isinstance(result, str)
    assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
|
||||
@ -1,58 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.minimax.text_embedding.text_embedding import MinimaxTextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
    """Invalid API key must raise; env-provided key and group id must validate."""
    model = MinimaxTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="embo-01",
            credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")},
        )

    model.validate_credentials(
        model="embo-01",
        credentials={
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
    )
|
||||
|
||||
|
||||
def test_invoke_model():
    """Embed two texts and check the embedding count and pinned token usage."""
    model = MinimaxTextEmbeddingModel()

    result = model.invoke(
        model="embo-01",
        credentials={
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    # service-reported usage for these two inputs; update if the API changes
    assert result.usage.total_tokens == 16
|
||||
|
||||
|
||||
def test_get_num_tokens():
    """Local token estimate for two one-word texts is 2 (one per text)."""
    model = MinimaxTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model="embo-01",
        credentials={
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        texts=["hello", "world"],
    )

    assert num_tokens == 2
|
||||
@ -1,143 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.minimax.llm.llm import MinimaxLargeLanguageModel
|
||||
|
||||
|
||||
def test_predefined_models():
    """The provider ships at least one predefined model schema."""
    model = MinimaxLargeLanguageModel()
    model_schemas = model.predefined_models()
    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
|
||||
def test_validate_credentials_for_chat_model():
    """Invalid key/group must raise; env-provided credentials must validate."""
    # crude rate-limit guard between consecutive live API tests
    sleep(3)
    model = MinimaxLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"}
        )

    model.validate_credentials(
        model="abab5.5-chat",
        credentials={
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
    )
|
||||
|
||||
|
||||
def test_invoke_model():
    """Non-streaming invocation returns a populated LLMResult."""
    # crude rate-limit guard between consecutive live API tests
    sleep(3)
    model = MinimaxLargeLanguageModel()

    response = model.invoke(
        model="abab5-chat",
        credentials={
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
        },
        stop=["you"],
        user="abc-123",
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
    """Streaming invocation yields well-formed chunks."""
    # crude rate-limit guard between consecutive live API tests
    sleep(3)
    model = MinimaxLargeLanguageModel()

    response = model.invoke(
        model="abab5.5-chat",
        credentials={
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # the final chunk (finish_reason set) may carry empty content
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_invoke_with_search():
    """Streaming invocation with the web-search plugin cites references in the answer."""
    # crude rate-limit guard between consecutive live API tests
    sleep(3)
    model = MinimaxLargeLanguageModel()

    response = model.invoke(
        model="abab5.5-chat",
        credentials={
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
            "plugin_web_search": True,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)
    total_message = ""
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        total_message += chunk.delta.message.content
        assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True

    # the web-search plugin appends a "references" section to the answer
    assert "参考资料" in total_message
|
||||
|
||||
|
||||
def test_get_num_tokens():
    """Token count for a single user message (no tools) is pinned at 30."""
    # crude rate-limit guard between consecutive live API tests
    sleep(3)
    model = MinimaxLargeLanguageModel()

    response = model.get_num_tokens(
        model="abab5.5-chat",
        credentials={
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        tools=[],
    )

    assert isinstance(response, int)
    assert response == 30
|
||||
@ -1,25 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.minimax.minimax import MinimaxProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
    """Provider-level validation: bogus keys raise, env-supplied keys pass."""
    provider = MinimaxProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={
                "minimax_api_key": "hahahaha",
                "minimax_group_id": "123",
            }
        )

    provider.validate_provider_credentials(
        credentials={
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        }
    )
|
||||
@ -1,28 +0,0 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.mixedbread.mixedbread import MixedBreadProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
    """Bogus key must raise; with the HTTP layer mocked, the env key validates."""
    provider = MixedBreadProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
    # mock the embeddings endpoint so validation succeeds without network access
    with patch("requests.post") as mock_post:
        mock_response = Mock()
        mock_response.json.return_value = {
            "usage": {"prompt_tokens": 3, "total_tokens": 3},
            "model": "mixedbread-ai/mxbai-embed-large-v1",
            "data": [{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"}],
            "object": "list",
            "normalized": "true",
            "encoding_format": "float",
            "dimensions": 1024,
        }
        mock_response.status_code = 200
        mock_post.return_value = mock_response
        provider.validate_provider_credentials(credentials={"api_key": os.environ.get("MIXEDBREAD_API_KEY")})
|
||||
@ -1,100 +0,0 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.mixedbread.rerank.rerank import MixedBreadRerankModel
|
||||
|
||||
|
||||
def test_validate_credentials():
    """Bogus key must raise; with the HTTP layer mocked, the env key validates."""
    model = MixedBreadRerankModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="mxbai-rerank-large-v1",
            credentials={"api_key": "invalid_key"},
        )
    # mock the rerank endpoint so validation succeeds without network access
    with patch("httpx.post") as mock_post:
        mock_response = Mock()
        mock_response.json.return_value = {
            "usage": {"prompt_tokens": 86, "total_tokens": 86},
            "model": "mixedbread-ai/mxbai-rerank-large-v1",
            "data": [
                {
                    "index": 0,
                    "score": 0.06762695,
                    "input": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
                    "States Census, Carson City had a population of 55,274.",
                    "object": "text_document",
                },
                {
                    "index": 1,
                    "score": 0.057403564,
                    "input": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific "
                    "Ocean that are a political division controlled by the United States. Its capital is "
                    "Saipan.",
                    "object": "text_document",
                },
            ],
            "object": "list",
            "top_k": 2,
            "return_input": True,
        }
        mock_response.status_code = 200
        mock_post.return_value = mock_response
        model.validate_credentials(
            model="mxbai-rerank-large-v1",
            credentials={
                "api_key": os.environ.get("MIXEDBREAD_API_KEY"),
            },
        )
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = MixedBreadRerankModel()
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"usage": {"prompt_tokens": 56, "total_tokens": 56},
|
||||
"model": "mixedbread-ai/mxbai-rerank-large-v1",
|
||||
"data": [
|
||||
{
|
||||
"index": 0,
|
||||
"score": 0.6044922,
|
||||
"input": "Kasumi is a girl name of Japanese origin meaning mist.",
|
||||
"object": "text_document",
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"score": 0.0703125,
|
||||
"input": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a "
|
||||
"team named PopiParty.",
|
||||
"object": "text_document",
|
||||
},
|
||||
],
|
||||
"object": "list",
|
||||
"top_k": 2,
|
||||
"return_input": "true",
|
||||
}
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
result = model.invoke(
|
||||
model="mxbai-rerank-large-v1",
|
||||
credentials={
|
||||
"api_key": os.environ.get("MIXEDBREAD_API_KEY"),
|
||||
},
|
||||
query="Who is Kasumi?",
|
||||
docs=[
|
||||
"Kasumi is a girl name of Japanese origin meaning mist.",
|
||||
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
|
||||
"PopiParty.",
|
||||
],
|
||||
score_threshold=0.5,
|
||||
)
|
||||
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 1
|
||||
assert result.docs[0].index == 0
|
||||
assert result.docs[0].score >= 0.5
|
||||
@ -1,78 +0,0 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.mixedbread.text_embedding.text_embedding import MixedBreadTextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = MixedBreadTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="mxbai-embed-large-v1", credentials={"api_key": "invalid_key"})
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"usage": {"prompt_tokens": 3, "total_tokens": 3},
|
||||
"model": "mixedbread-ai/mxbai-embed-large-v1",
|
||||
"data": [{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"}],
|
||||
"object": "list",
|
||||
"normalized": "true",
|
||||
"encoding_format": "float",
|
||||
"dimensions": 1024,
|
||||
}
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
model.validate_credentials(
|
||||
model="mxbai-embed-large-v1", credentials={"api_key": os.environ.get("MIXEDBREAD_API_KEY")}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = MixedBreadTextEmbeddingModel()
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"usage": {"prompt_tokens": 6, "total_tokens": 6},
|
||||
"model": "mixedbread-ai/mxbai-embed-large-v1",
|
||||
"data": [
|
||||
{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"},
|
||||
{"embedding": [0.23333 for _ in range(1024)], "index": 1, "object": "embedding"},
|
||||
],
|
||||
"object": "list",
|
||||
"normalized": "true",
|
||||
"encoding_format": "float",
|
||||
"dimensions": 1024,
|
||||
}
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
result = model.invoke(
|
||||
model="mxbai-embed-large-v1",
|
||||
credentials={
|
||||
"api_key": os.environ.get("MIXEDBREAD_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 6
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = MixedBreadTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="mxbai-embed-large-v1",
|
||||
credentials={
|
||||
"api_key": os.environ.get("MIXEDBREAD_API_KEY"),
|
||||
},
|
||||
texts=["ping"],
|
||||
)
|
||||
|
||||
assert num_tokens == 1
|
||||
@ -1,62 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.nomic.text_embedding.text_embedding import NomicTextEmbeddingModel
|
||||
from tests.integration_tests.model_runtime.__mock.nomic_embeddings import setup_nomic_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
|
||||
def test_validate_credentials(setup_nomic_mock):
|
||||
model = NomicTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="nomic-embed-text-v1.5",
|
||||
credentials={
|
||||
"nomic_api_key": "invalid_key",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="nomic-embed-text-v1.5",
|
||||
credentials={
|
||||
"nomic_api_key": os.environ.get("NOMIC_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
|
||||
def test_invoke_model(setup_nomic_mock):
|
||||
model = NomicTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="nomic-embed-text-v1.5",
|
||||
credentials={
|
||||
"nomic_api_key": os.environ.get("NOMIC_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert result.model == "nomic-embed-text-v1.5"
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
|
||||
def test_get_num_tokens(setup_nomic_mock):
|
||||
model = NomicTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="nomic-embed-text-v1.5",
|
||||
credentials={
|
||||
"nomic_api_key": os.environ.get("NOMIC_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
@ -1,21 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.nomic.nomic import NomicAtlasProvider
|
||||
from tests.integration_tests.model_runtime.__mock.nomic_embeddings import setup_nomic_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
|
||||
def test_validate_provider_credentials(setup_nomic_mock):
|
||||
provider = NomicAtlasProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={})
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
"nomic_api_key": os.environ.get("NOMIC_API_KEY"),
|
||||
},
|
||||
)
|
||||
@ -1,98 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.novita.llm.llm import NovitaLargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = NovitaLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="meta-llama/llama-3-8b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="meta-llama/llama-3-8b-instruct",
|
||||
credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = NovitaLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="meta-llama/llama-3-8b-instruct",
|
||||
credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "completion"},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Who are you?"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 1.0,
|
||||
"top_p": 0.5,
|
||||
"max_tokens": 10,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="novita",
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = NovitaLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="meta-llama/llama-3-8b-instruct",
|
||||
credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Who are you?"),
|
||||
],
|
||||
model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="novita",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = NovitaLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="meta-llama/llama-3-8b-instruct",
|
||||
credentials={
|
||||
"api_key": os.environ.get("NOVITA_API_KEY"),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 21
|
||||
@ -1,19 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.novita.novita import NovitaProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = NovitaProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={})
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
"api_key": os.environ.get("NOVITA_API_KEY"),
|
||||
}
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user