Model Runtime (#1858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
takatost
2024-01-02 23:42:00 +08:00
committed by GitHub
parent e91dd28a76
commit d069c668f8
807 changed files with 171310 additions and 23806 deletions

View File

@ -1,7 +1,6 @@
import copy
from core.model_providers.models.entity.model_params import ModelMode
from core.prompt.prompt_transform import AppMode
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
@ -25,14 +24,14 @@ class AdvancedPromptTemplateService:
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
if model_mode == "completion":
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == ModelMode.CHAT.value:
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
if model_mode == "completion":
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == ModelMode.CHAT.value:
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
@classmethod
@ -54,12 +53,12 @@ class AdvancedPromptTemplateService:
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
if model_mode == "completion":
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == ModelMode.CHAT.value:
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
if model_mode == "completion":
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == ModelMode.CHAT.value:
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)

View File

@ -2,11 +2,12 @@ import re
import uuid
from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.moderation.factory import ModerationFactory
from core.prompt.prompt_transform import AppMode
from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType, ModelMode
from core.provider_manager import ProviderManager
from models.account import Account
from services.dataset_service import DatasetService
@ -34,26 +35,6 @@ class AppModelConfigService:
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")
# max_tokens
if 'max_tokens' not in cp:
cp["max_tokens"] = 512
# temperature
if 'temperature' not in cp:
cp["temperature"] = 1
# top_p
if 'top_p' not in cp:
cp["top_p"] = 1
# presence_penalty
if 'presence_penalty' not in cp:
cp["presence_penalty"] = 0
# presence_penalty
if 'frequency_penalty' not in cp:
cp["frequency_penalty"] = 0
# stop
if 'stop' not in cp:
cp["stop"] = []
@ -63,20 +44,10 @@ class AppModelConfigService:
if len(cp["stop"]) > 4:
raise ValueError("stop sequences must be less than 4")
# Filter out extra parameters
filtered_cp = {
"max_tokens": cp["max_tokens"],
"temperature": cp["temperature"],
"top_p": cp["top_p"],
"presence_penalty": cp["presence_penalty"],
"frequency_penalty": cp["frequency_penalty"],
"stop": cp["stop"]
}
return filtered_cp
return cp
@classmethod
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict:
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, app_mode: str) -> dict:
# opening_statement
if 'opening_statement' not in config or not config["opening_statement"]:
config["opening_statement"] = ""
@ -140,21 +111,6 @@ class AppModelConfigService:
if not isinstance(config["retriever_resource"]["enabled"], bool):
raise ValueError("enabled in retriever_resource must be of boolean type")
# annotation reply
if 'annotation_reply' not in config or not config["annotation_reply"]:
config["annotation_reply"] = {
"enabled": False
}
if not isinstance(config["annotation_reply"], dict):
raise ValueError("annotation_reply must be of dict type")
if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]:
config["annotation_reply"]["enabled"] = False
if not isinstance(config["annotation_reply"]["enabled"], bool):
raise ValueError("enabled in annotation_reply must be of boolean type")
# more_like_this
if 'more_like_this' not in config or not config["more_like_this"]:
config["more_like_this"] = {
@ -178,7 +134,8 @@ class AppModelConfigService:
raise ValueError("model must be of object type")
# model.provider
model_provider_names = ModelProviderFactory.get_provider_names()
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
@ -186,18 +143,29 @@ class AppModelConfigService:
if 'name' not in config["model"]:
raise ValueError("model.name is required")
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, config["model"]["provider"])
if not model_provider:
provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"],
model_type=ModelType.LLM
)
if not models:
raise ValueError("model.name must be in the specified model list")
model_list = model_provider.get_supported_model_list(ModelType.TEXT_GENERATION)
model_ids = [m['id'] for m in model_list]
model_ids = [m.model for m in models]
if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list")
model_mode = None
for model in models:
if model.model == config["model"]["name"]:
model_mode = model.model_properties.get(ModelPropertyKey.MODE)
break
# model.mode
if 'mode' not in config['model'] or not config['model']["mode"]:
config['model']["mode"] = ""
if model_mode:
config['model']["mode"] = model_mode
else:
config['model']["mode"] = "completion"
# model.completion_params
if 'completion_params' not in config["model"]:
@ -319,10 +287,10 @@ class AppModelConfigService:
raise ValueError("Dataset ID does not exist, please check your permission.")
# dataset_query_variable
cls.is_dataset_query_variable_valid(config, mode)
cls.is_dataset_query_variable_valid(config, app_mode)
# advanced prompt validation
cls.is_advanced_prompt_valid(config, mode)
cls.is_advanced_prompt_valid(config, app_mode)
# external data tools validation
cls.is_external_data_tools_valid(tenant_id, config)
@ -340,7 +308,6 @@ class AppModelConfigService:
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
"speech_to_text": config["speech_to_text"],
"retriever_resource": config["retriever_resource"],
"annotation_reply": config["annotation_reply"],
"more_like_this": config["more_like_this"],
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
"external_data_tools": config["external_data_tools"],
@ -507,7 +474,7 @@ class AppModelConfigService:
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT.value and config['model']["mode"] == "completion":
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
@ -517,7 +484,7 @@ class AppModelConfigService:
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
if config['model']["mode"] == ModelMode.CHAT.value:
if config['model']["mode"] == "chat":
prompt_list = config['chat_prompt_config']['prompt']
if len(prompt_list) > 10:

View File

@ -1,6 +1,8 @@
import io
from werkzeug.datastructures import FileStorage
from core.model_providers.model_factory import ModelFactory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
FILE_SIZE = 15
@ -25,11 +27,13 @@ class AudioService:
message = f"Audio size larger than {FILE_SIZE} mb"
raise AudioTooLargeServiceError(message)
model = ModelFactory.get_speech2text_model(
tenant_id=tenant_id
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.SPEECH2TEXT
)
buffer = io.BytesIO(file_content)
buffer.name = 'temp.mp3'
return model.run(buffer)
return {"text": model_instance.invoke_speech2text(buffer)}

View File

@ -1,29 +1,16 @@
import json
import logging
import threading
import time
import uuid
from typing import Generator, Union, Any, Optional, List
from typing import Generator, Union, Any
from flask import current_app, Flask
from redis.client import PubSub
from sqlalchemy import and_
from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.application_manager import ApplicationManager
from core.entities.application_entities import InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, \
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.models.entity.message import PromptMessageFile
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
from services.app_model_config_service import AppModelConfigService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.completion import CompletionStoppedError
from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
from services.errors.message import MessageNotExistsError
@ -32,7 +19,7 @@ class CompletionService:
@classmethod
def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
from_source: str, streaming: bool = True,
invoke_from: InvokeFrom, streaming: bool = True,
is_model_config_override: bool = False) -> Union[dict, Generator]:
# is streaming mode
inputs = args['inputs']
@ -56,7 +43,7 @@ class CompletionService:
Conversation.status == 'normal'
]
if from_source == 'console':
if isinstance(user, Account):
conversation_filter.append(Conversation.from_account_id == user.id)
else:
conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
@ -124,7 +111,7 @@ class CompletionService:
tenant_id=app_model.tenant_id,
account=user,
config=args['model_config'],
mode=app_model.mode
app_mode=app_model.mode
)
app_model_config = AppModelConfig(
@ -145,134 +132,29 @@ class CompletionService:
user
)
generate_task_id = str(uuid.uuid4())
pubsub = redis_client.pubsub()
pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
user = cls.get_real_user_instead_of_proxy_obj(user)
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id,
'detached_app_model': app_model,
'app_model_config': app_model_config.copy(),
'query': query,
'inputs': inputs,
'files': file_objs,
'detached_user': user,
'detached_conversation': conversation,
'streaming': streaming,
'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
'auto_generate_name': auto_generate_name,
'from_source': from_source
})
generate_worker_thread.start()
# wait for 10 minutes to close the thread
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
generate_task_id)
return cls.compact_response(pubsub, streaming)
@classmethod
def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
if isinstance(user, Account):
user = db.session.query(Account).filter(Account.id == user.id).first()
elif isinstance(user, EndUser):
user = db.session.query(EndUser).filter(EndUser.id == user.id).first()
else:
raise Exception("Unknown user type")
return user
@classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
app_model_config: AppModelConfig,
query: str, inputs: dict, files: List[PromptMessageFile],
detached_user: Union[Account, EndUser],
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
app_model = db.session.merge(detached_app_model)
if detached_conversation:
conversation = db.session.merge(detached_conversation)
else:
conversation = None
try:
# run
Completion.generate(
task_id=generate_task_id,
app=app_model,
app_model_config=app_model_config,
query=query,
inputs=inputs,
user=user,
files=files,
conversation=conversation,
streaming=streaming,
is_override=is_model_config_override,
retriever_from=retriever_from,
auto_generate_name=auto_generate_name,
from_source=from_source
)
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
pass
except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
finally:
db.session.remove()
@classmethod
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
generate_task_id) -> threading.Thread:
# wait for 10 minutes to close the thread
timeout = 600
def close_pubsub():
with flask_app.app_context():
try:
user = db.session.merge(detached_user)
sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive():
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
PubHandler.ping(user, generate_task_id)
time.sleep(1)
sleep_iterations += 1
if worker_thread.is_alive():
PubHandler.stop(user, generate_task_id)
try:
pubsub.close()
except Exception:
pass
finally:
db.session.remove()
countdown_thread = threading.Thread(target=close_pubsub)
countdown_thread.start()
return countdown_thread
application_manager = ApplicationManager()
return application_manager.generate(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_model_config_id=app_model_config.id,
app_model_config_dict=app_model_config.to_dict(),
app_model_config_override=is_model_config_override,
user=user,
invoke_from=invoke_from,
inputs=inputs,
query=query,
files=file_objs,
conversation=conversation,
stream=streaming,
extras={
"auto_generate_conversation_name": auto_generate_name
}
)
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
message_id: str, streaming: bool = True,
retriever_from: str = 'dev') -> Union[dict, Generator]:
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
-> Union[dict, Generator]:
if not user:
raise ValueError('user cannot be None')
@ -306,36 +188,24 @@ class CompletionService:
message.files, app_model_config
)
generate_task_id = str(uuid.uuid4())
pubsub = redis_client.pubsub()
pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
user = cls.get_real_user_instead_of_proxy_obj(user)
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id,
'detached_app_model': app_model,
'app_model_config': app_model_config.copy(),
'query': message.query,
'inputs': message.inputs,
'files': file_objs,
'detached_user': user,
'detached_conversation': None,
'streaming': streaming,
'is_model_config_override': True,
'retriever_from': retriever_from,
'auto_generate_name': False
})
generate_worker_thread.start()
# wait for 10 minutes to close the thread
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
generate_task_id)
return cls.compact_response(pubsub, streaming)
application_manager = ApplicationManager()
return application_manager.generate(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_model_config_id=app_model_config.id,
app_model_config_dict=app_model_config.to_dict(),
app_model_config_override=True,
user=user,
invoke_from=invoke_from,
inputs=message.inputs,
query=message.query,
files=file_objs,
conversation=None,
stream=streaming,
extras={
"auto_generate_conversation_name": False
}
)
@classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
@ -375,247 +245,3 @@ class CompletionService:
return filtered_inputs
@classmethod
def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
if not streaming:
try:
message_result = {}
for message in pubsub.listen():
if message["type"] == "message":
result = message["data"].decode('utf-8')
result = json.loads(result)
if result.get('error'):
cls.handle_error(result)
if result['event'] == 'annotation' and 'data' in result:
message_result['annotation'] = result.get('data')
return cls.get_blocking_annotation_message_response_data(message_result)
if result['event'] == 'message' and 'data' in result:
message_result['message'] = result.get('data')
if result['event'] == 'message_end' and 'data' in result:
message_result['message_end'] = result.get('data')
return cls.get_blocking_message_response_data(message_result)
except ValueError as e:
if e.args[0] != "I/O operation on closed file.": # ignore this error
raise CompletionStoppedError()
else:
logging.exception(e)
raise
finally:
db.session.remove()
try:
pubsub.unsubscribe(generate_channel)
except ConnectionError:
pass
else:
def generate() -> Generator:
try:
for message in pubsub.listen():
if message["type"] == "message":
result = message["data"].decode('utf-8')
result = json.loads(result)
if result.get('error'):
cls.handle_error(result)
event = result.get('event')
if event == "end":
logging.debug("{} finished".format(generate_channel))
break
if event == 'message':
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
elif event == 'message_replace':
yield "data: " + json.dumps(
cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
elif event == 'chain':
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
elif event == 'agent_thought':
yield "data: " + json.dumps(
cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
elif event == 'annotation':
yield "data: " + json.dumps(
cls.get_annotation_response_data(result.get('data'))) + "\n\n"
elif event == 'message_end':
yield "data: " + json.dumps(
cls.get_message_end_data(result.get('data'))) + "\n\n"
elif event == 'ping':
yield "event: ping\n\n"
else:
yield "data: " + json.dumps(result) + "\n\n"
except ValueError as e:
if e.args[0] != "I/O operation on closed file.": # ignore this error
logging.exception(e)
raise
finally:
db.session.remove()
try:
pubsub.unsubscribe(generate_channel)
except ConnectionError:
pass
return generate()
@classmethod
def get_message_response_data(cls, data: dict):
response_data = {
'event': 'message',
'task_id': data.get('task_id'),
'id': data.get('message_id'),
'answer': data.get('text'),
'created_at': int(time.time())
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def get_message_replace_response_data(cls, data: dict):
response_data = {
'event': 'message_replace',
'task_id': data.get('task_id'),
'id': data.get('message_id'),
'answer': data.get('text'),
'created_at': int(time.time())
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def get_blocking_message_response_data(cls, data: dict):
message = data.get('message')
response_data = {
'event': 'message',
'task_id': message.get('task_id'),
'id': message.get('message_id'),
'answer': message.get('text'),
'metadata': {},
'created_at': int(time.time())
}
if message.get('mode') == 'chat':
response_data['conversation_id'] = message.get('conversation_id')
if 'message_end' in data:
message_end = data.get('message_end')
if 'retriever_resources' in message_end:
response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources')
return response_data
@classmethod
def get_blocking_annotation_message_response_data(cls, data: dict):
message = data.get('annotation')
response_data = {
'event': 'annotation',
'task_id': message.get('task_id'),
'id': message.get('message_id'),
'answer': message.get('text'),
'metadata': {},
'created_at': int(time.time()),
'annotation_id': message.get('annotation_id'),
'annotation_author_name': message.get('annotation_author_name')
}
if message.get('mode') == 'chat':
response_data['conversation_id'] = message.get('conversation_id')
return response_data
@classmethod
def get_message_end_data(cls, data: dict):
response_data = {
'event': 'message_end',
'task_id': data.get('task_id'),
'id': data.get('message_id')
}
if 'retriever_resources' in data:
response_data['retriever_resources'] = data.get('retriever_resources')
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def get_chain_response_data(cls, data: dict):
response_data = {
'event': 'chain',
'id': data.get('chain_id'),
'task_id': data.get('task_id'),
'message_id': data.get('message_id'),
'type': data.get('type'),
'input': data.get('input'),
'output': data.get('output'),
'created_at': int(time.time())
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def get_agent_thought_response_data(cls, data: dict):
response_data = {
'event': 'agent_thought',
'id': data.get('id'),
'chain_id': data.get('chain_id'),
'task_id': data.get('task_id'),
'message_id': data.get('message_id'),
'position': data.get('position'),
'thought': data.get('thought'),
'tool': data.get('tool'),
'tool_input': data.get('tool_input'),
'created_at': int(time.time())
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def get_annotation_response_data(cls, data: dict):
response_data = {
'event': 'annotation',
'task_id': data.get('task_id'),
'id': data.get('message_id'),
'answer': data.get('text'),
'created_at': int(time.time()),
'annotation_id': data.get('annotation_id'),
'annotation_author_name': data.get('annotation_author_name'),
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def handle_error(cls, result: dict):
logging.debug("error: %s", result)
error = result.get('error')
description = result.get('description')
# handle errors
llm_errors = {
'ValueError': LLMBadRequestError,
'LLMBadRequestError': LLMBadRequestError,
'LLMAPIConnectionError': LLMAPIConnectionError,
'LLMAPIUnavailableError': LLMAPIUnavailableError,
'LLMRateLimitError': LLMRateLimitError,
'ProviderTokenNotInitError': ProviderTokenNotInitError,
'QuotaExceededError': QuotaExceededError,
'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
}
if error in llm_errors:
raise llm_errors[error](description)
elif error == 'LLMAuthorizationError':
raise LLMAuthorizationError('Incorrect API key provided')
else:
raise Exception(description)

View File

@ -4,14 +4,16 @@ import datetime
import time
import random
import uuid
from typing import Optional, List
from typing import Optional, List, cast
from flask import current_app
from sqlalchemy import func
from core.index.index import IndexBuilder
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_redis import redis_client
from flask_login import current_user
@ -92,16 +94,18 @@ class DatasetService:
f'Dataset with name {name} already exists.')
embedding_model = None
if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config)
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
dataset.embedding_model = embedding_model.name if embedding_model else None
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model if embedding_model else None
db.session.add(dataset)
db.session.commit()
return dataset
@ -120,10 +124,12 @@ class DatasetService:
def check_dataset_model_setting(dataset):
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ValueError(
@ -150,14 +156,16 @@ class DatasetService:
action = 'add'
# get embedding model setting
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
filtered_data['embedding_model'] = embedding_model.name
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
embedding_model.provider,
embedding_model.model
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
@ -458,14 +466,16 @@ class DocumentService:
dataset.indexing_technique = document_data["indexing_technique"]
if document_data["indexing_technique"] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
embedding_model.provider,
embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
@ -737,12 +747,14 @@ class DocumentService:
dataset_collection_binding_id = None
retrieval_model = None
if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
embedding_model.provider,
embedding_model.model
)
dataset_collection_binding_id = dataset_collection_binding.id
if 'retrieval_model' in document_data and document_data['retrieval_model']:
@ -766,8 +778,8 @@ class DocumentService:
data_source_type=document_data["data_source"]["type"],
indexing_technique=document_data["indexing_technique"],
created_by=account.id,
embedding_model=embedding_model.name if embedding_model else None,
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
embedding_model=embedding_model.model if embedding_model else None,
embedding_model_provider=embedding_model.provider if embedding_model else None,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model
)
@ -989,13 +1001,20 @@ class SegmentService:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content]
)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
@ -1037,10 +1056,12 @@ class SegmentService:
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
@ -1054,7 +1075,12 @@ class SegmentService:
tokens = 0
if dataset.indexing_technique == 'high_quality' and embedding_model:
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content]
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
@ -1121,14 +1147,21 @@ class SegmentService:
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content]
)
segment.content = content
segment.index_node_hash = segment_hash
segment.word_count = len(content)

View File

View File

@ -0,0 +1,152 @@
from enum import Enum
from typing import Optional
from flask import current_app
from pydantic import BaseModel
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity
from core.entities.provider_entities import QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderCredentialSchema, \
ModelCredentialSchema, ProviderHelpEntity, SimpleProviderEntity
from models.provider import ProviderType, ProviderQuotaType
class CustomConfigurationStatus(Enum):
    """
    Enum class for custom configuration status.
    """
    # Custom (user-supplied) credentials are configured for the provider
    ACTIVE = 'active'
    # No custom credentials have been configured yet
    NO_CONFIGURE = 'no-configure'
class CustomConfigurationResponse(BaseModel):
    """
    Model class for provider custom configuration response.
    """
    # Whether custom credentials exist for this provider (see CustomConfigurationStatus)
    status: CustomConfigurationStatus
class SystemConfigurationResponse(BaseModel):
    """
    Model class for provider system configuration response.
    """
    # Whether the system (hosted) configuration is enabled for this provider
    enabled: bool
    # Quota type currently in effect, if any
    current_quota_type: Optional[ProviderQuotaType] = None
    # One entry per supported quota type (pydantic copies this default per instance)
    quota_configurations: list[QuotaConfiguration] = []
class ProviderResponse(BaseModel):
    """
    Model class for provider response.

    On construction, icon fields are rewritten to console API URLs so that
    clients fetch provider icons through the workspace endpoints.
    """
    provider: str
    label: I18nObject
    description: Optional[I18nObject] = None
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    background: Optional[str] = None
    help: Optional[ProviderHelpEntity] = None
    supported_model_types: list[ModelType]
    configurate_methods: list[ConfigurateMethod]
    provider_credential_schema: Optional[ProviderCredentialSchema] = None
    model_credential_schema: Optional[ModelCredentialSchema] = None
    preferred_provider_type: ProviderType
    custom_configuration: CustomConfigurationResponse
    system_configuration: SystemConfigurationResponse

    def __init__(self, **data) -> None:
        super().__init__(**data)

        base = (current_app.config.get("CONSOLE_API_URL")
                + f"/console/api/workspaces/current/model-providers/{self.provider}")

        # Replace each populated icon with localized console API URLs.
        for icon_field in ('icon_small', 'icon_large'):
            if getattr(self, icon_field) is not None:
                setattr(self, icon_field, I18nObject(
                    en_US=f"{base}/{icon_field}/en_US",
                    zh_Hans=f"{base}/{icon_field}/zh_Hans"
                ))
class ModelResponse(ProviderModel):
    """
    Model class for model response.
    """
    # Availability status of the model for the current workspace
    status: ModelStatus
class ProviderWithModelsResponse(BaseModel):
    """
    Model class for provider with models response.

    On construction, icon fields are rewritten to console API URLs so that
    clients fetch provider icons through the workspace endpoints.
    """
    provider: str
    label: I18nObject
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    status: CustomConfigurationStatus
    models: list[ModelResponse]

    def __init__(self, **data) -> None:
        super().__init__(**data)

        base = (current_app.config.get("CONSOLE_API_URL")
                + f"/console/api/workspaces/current/model-providers/{self.provider}")

        # Replace each populated icon with localized console API URLs.
        for icon_field in ('icon_small', 'icon_large'):
            if getattr(self, icon_field) is not None:
                setattr(self, icon_field, I18nObject(
                    en_US=f"{base}/{icon_field}/en_US",
                    zh_Hans=f"{base}/{icon_field}/zh_Hans"
                ))
class SimpleProviderEntityResponse(SimpleProviderEntity):
    """
    Simple provider entity response.

    Inherits all fields; only rewrites icon fields to console API URLs
    at construction time.
    """

    def __init__(self, **data) -> None:
        super().__init__(**data)

        base = (current_app.config.get("CONSOLE_API_URL")
                + f"/console/api/workspaces/current/model-providers/{self.provider}")

        # Replace each populated icon with localized console API URLs.
        for icon_field in ('icon_small', 'icon_large'):
            if getattr(self, icon_field) is not None:
                setattr(self, icon_field, I18nObject(
                    en_US=f"{base}/{icon_field}/en_US",
                    zh_Hans=f"{base}/{icon_field}/zh_Hans"
                ))
class DefaultModelResponse(BaseModel):
    """
    Default model entity.
    """
    # Model name
    model: str
    # Model type (e.g. LLM, text-embedding, rerank)
    model_type: ModelType
    # Provider this default model belongs to
    provider: SimpleProviderEntityResponse
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
    """
    Model with provider entity.
    """
    # Narrows the inherited provider field to the response type with icon URLs
    provider: SimpleProviderEntityResponse

    def __init__(self, model: ModelWithProviderEntity) -> None:
        # Re-validate the source entity's fields; the provider sub-dict is
        # coerced into SimpleProviderEntityResponse by pydantic.
        super().__init__(**model.dict())

View File

@ -1,4 +1,3 @@
import json
import logging
import threading
import time
@ -11,7 +10,9 @@ from langchain.schema import Document
from sklearn.manifold import TSNE
from core.embedding.cached_embedding import CacheEmbedding
from core.model_providers.model_factory import ModelFactory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery
@ -47,11 +48,14 @@ class HitTestingService:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get embedding model
embedding_model = ModelFactory.get_embedding_model(
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
model_type=ModelType.TEXT_EMBEDDING,
provider=dataset.embedding_model_provider,
model=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
all_documents = []
@ -93,14 +97,22 @@ class HitTestingService:
thread.join()
if retrieval_model['search_method'] == 'hybrid_search':
hybrid_rerank = ModelFactory.get_reranking_model(
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
model_name=retrieval_model['reranking_model']['reranking_model_name']
provider=retrieval_model['reranking_model']['reranking_provider_name'],
model_type=ModelType.RERANK,
model=retrieval_model['reranking_model']['reranking_model_name']
)
rerank_runner = RerankRunner(rerank_model_instance)
all_documents = rerank_runner.run(
query=query,
documents=all_documents,
score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
top_n=retrieval_model['top_k'],
user=f"account-{account.id}"
)
all_documents = hybrid_rerank.rerank(query, all_documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
retrieval_model['top_k'])
end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")

View File

@ -1,8 +1,10 @@
import json
from typing import Optional, Union, List
from core.completion import Completion
from core.generator.llm_generator import LLMGenerator
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.account import Account
@ -216,21 +218,27 @@ class MessageService:
raise SuggestedQuestionsAfterAnswerDisabledError()
# get memory of conversation (read-only)
memory = Completion.get_memory_from_conversation(
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=app_model.tenant_id,
app_model_config=app_model_config,
conversation=conversation,
max_token_limit=3000,
message_limit=3,
return_messages=False,
memory_key="histories"
provider=app_model_config.model_dict['provider'],
model_type=ModelType.LLM,
model=app_model_config.model_dict['name']
)
external_context = memory.load_memory_variables({})
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
histories = memory.get_history_prompt_text(
max_token_limit=3000,
message_limit=3,
)
questions = LLMGenerator.generate_suggested_questions_after_answer(
tenant_id=app_model.tenant_id,
**external_context
histories=histories
)
return questions

View File

@ -0,0 +1,530 @@
import logging
import mimetypes
import os
from typing import Optional, cast, Tuple
import requests
from flask import current_app
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, DefaultModelEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.provider_manager import ProviderManager
from models.provider import ProviderType
from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \
SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \
DefaultModelResponse, ModelWithProviderEntityResponse
logger = logging.getLogger(__name__)
class ModelProviderService:
    """
    Model Provider Service

    Facade over ProviderManager for everything the console needs to manage
    model providers: listing, credential CRUD, default models, icons and
    free-quota workflows.
    """

    def __init__(self) -> None:
        self.provider_manager = ProviderManager()

    def _get_provider_configuration(self, tenant_id: str, provider: str):
        """
        Fetch a single provider's configuration for a workspace.

        :param tenant_id: workspace id
        :param provider: provider name
        :return: the provider configuration
        :raises ValueError: if the provider is unknown for this workspace
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)

        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")

        return provider_configuration

    def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
        """
        get provider list.

        :param tenant_id: workspace id
        :param model_type: optional model type filter (string form of ModelType)
        :return: list of provider responses
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)

        provider_responses = []
        for provider_configuration in provider_configurations.values():
            if model_type:
                model_type_entity = ModelType.value_of(model_type)
                if model_type_entity not in provider_configuration.provider.supported_model_types:
                    continue

            provider_response = ProviderResponse(
                **provider_configuration.provider.dict(),
                preferred_provider_type=provider_configuration.preferred_provider_type,
                custom_configuration=CustomConfigurationResponse(
                    status=CustomConfigurationStatus.ACTIVE
                    if provider_configuration.is_custom_configuration_available()
                    else CustomConfigurationStatus.NO_CONFIGURE
                ),
                system_configuration=SystemConfigurationResponse(
                    **provider_configuration.system_configuration.dict()
                )
            )

            provider_responses.append(provider_response)

        return provider_responses

    def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWithProviderEntityResponse]:
        """
        get provider models.
        For the model provider page,
        only supports passing in a single provider to query the list of supported models.

        :param tenant_id: workspace id
        :param provider: provider name
        :return: models available under this provider
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)

        # Get provider available models
        return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(
            provider=provider
        )]

    def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
        """
        get provider credentials.

        :param tenant_id: workspace id
        :param provider: provider name
        :return: obfuscated custom credentials
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Get provider custom credentials from workspace
        return provider_configuration.get_custom_credentials(obfuscated=True)

    def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
        """
        validate provider credentials.

        :param tenant_id: workspace id
        :param provider: provider name
        :param credentials: provider credentials to validate
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        provider_configuration.custom_credentials_validate(credentials)

    def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
        """
        save custom provider config.

        :param tenant_id: workspace id
        :param provider: provider name
        :param credentials: provider credentials
        :return:
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Add or update custom provider credentials.
        provider_configuration.add_or_update_custom_credentials(credentials)

    def remove_provider_credentials(self, tenant_id: str, provider: str) -> None:
        """
        remove custom provider config.

        :param tenant_id: workspace id
        :param provider: provider name
        :return:
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Remove custom provider credentials.
        provider_configuration.delete_custom_credentials()

    def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
        """
        get model credentials.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :return: obfuscated model credentials
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Get model custom credentials from ProviderModel if exists
        return provider_configuration.get_custom_model_credentials(
            model_type=ModelType.value_of(model_type),
            model=model,
            obfuscated=True
        )

    def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str,
                                   credentials: dict) -> None:
        """
        validate model credentials.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param credentials: model credentials to validate
        :return:
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Validate model credentials
        provider_configuration.custom_model_credentials_validate(
            model_type=ModelType.value_of(model_type),
            model=model,
            credentials=credentials
        )

    def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str,
                               credentials: dict) -> None:
        """
        save model credentials.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Add or update custom model credentials
        provider_configuration.add_or_update_custom_model_credentials(
            model_type=ModelType.value_of(model_type),
            model=model,
            credentials=credentials
        )

    def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
        """
        remove model credentials.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :return:
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Remove custom model credentials
        provider_configuration.delete_custom_model_credentials(
            model_type=ModelType.value_of(model_type),
            model=model
        )

    def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
        """
        get models by model type.

        :param tenant_id: workspace id
        :param model_type: model type
        :return: providers with their non-deprecated models of this type
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)

        # Get provider available models
        models = provider_configurations.get_models(
            model_type=ModelType.value_of(model_type)
        )

        # Group models by provider, skipping deprecated ones
        provider_models = {}
        for model in models:
            if model.provider.provider not in provider_models:
                provider_models[model.provider.provider] = []

            if model.deprecated:
                continue

            provider_models[model.provider.provider].append(model)

        # convert to ProviderWithModelsResponse list
        providers_with_models: list[ProviderWithModelsResponse] = []
        for provider, provider_model_list in provider_models.items():
            if not provider_model_list:
                continue

            first_model = provider_model_list[0]
            has_active_models = any(model.status == ModelStatus.ACTIVE for model in provider_model_list)

            providers_with_models.append(
                ProviderWithModelsResponse(
                    provider=provider,
                    label=first_model.provider.label,
                    icon_small=first_model.provider.icon_small,
                    icon_large=first_model.provider.icon_large,
                    status=CustomConfigurationStatus.ACTIVE
                    if has_active_models else CustomConfigurationStatus.NO_CONFIGURE,
                    models=[ModelResponse(
                        model=model.model,
                        label=model.label,
                        model_type=model.model_type,
                        features=model.features,
                        fetch_from=model.fetch_from,
                        model_properties=model.model_properties,
                        status=model.status
                    ) for model in provider_model_list]
                )
            )

        return providers_with_models

    def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -> list[ParameterRule]:
        """
        get model parameter rules.
        Only supports LLM.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :return: parameter rules, or [] when no credentials are available
        """
        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Get model instance of LLM
        model_type_instance = provider_configuration.get_model_type_instance(ModelType.LLM)
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        # fetch credentials
        credentials = provider_configuration.get_current_credentials(
            model_type=ModelType.LLM,
            model=model
        )

        if not credentials:
            return []

        # Call get_parameter_rules method of model instance to get model parameter rules
        return model_type_instance.get_parameter_rules(
            model=model,
            credentials=credentials
        )

    def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
        """
        get default model of model type.

        :param tenant_id: workspace id
        :param model_type: model type
        :return: default model response, or None if none is set
        """
        model_type_enum = ModelType.value_of(model_type)
        result = self.provider_manager.get_default_model(
            tenant_id=tenant_id,
            model_type=model_type_enum
        )

        return DefaultModelResponse(
            **result.dict()
        ) if result else None

    def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
        """
        update default model of model type.

        :param tenant_id: workspace id
        :param model_type: model type
        :param provider: provider name
        :param model: model name
        :return:
        """
        model_type_enum = ModelType.value_of(model_type)
        self.provider_manager.update_default_model_record(
            tenant_id=tenant_id,
            model_type=model_type_enum,
            provider=provider,
            model=model
        )

    def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> Tuple[Optional[bytes], Optional[str]]:
        """
        get model provider icon.

        :param provider: provider name
        :param icon_type: icon type (icon_small or icon_large)
        :param lang: language (zh_Hans or en_US)
        :return: (icon bytes, mimetype), or (None, None) if the file is missing
        """
        provider_instance = model_provider_factory.get_provider_instance(provider)
        provider_schema = provider_instance.get_provider_schema()

        if icon_type.lower() == 'icon_small':
            if not provider_schema.icon_small:
                raise ValueError(f"Provider {provider} does not have small icon.")

            if lang.lower() == 'zh_hans':
                file_name = provider_schema.icon_small.zh_Hans
            else:
                file_name = provider_schema.icon_small.en_US
        else:
            if not provider_schema.icon_large:
                raise ValueError(f"Provider {provider} does not have large icon.")

            if lang.lower() == 'zh_hans':
                file_name = provider_schema.icon_large.zh_Hans
            else:
                file_name = provider_schema.icon_large.en_US

        # Icons live in the provider module's _assets directory.
        root_path = current_app.root_path
        provider_instance_path = os.path.dirname(
            os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/')))
        file_path = os.path.join(provider_instance_path, "_assets")
        file_path = os.path.join(file_path, file_name)

        if not os.path.exists(file_path):
            return None, None

        mimetype, _ = mimetypes.guess_type(file_path)
        mimetype = mimetype or 'application/octet-stream'

        # read binary from file
        with open(file_path, 'rb') as f:
            byte_data = f.read()

        return byte_data, mimetype

    def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
        """
        switch preferred provider.

        :param tenant_id: workspace id
        :param provider: provider name
        :param preferred_provider_type: preferred provider type
        :return:
        """
        # Convert preferred_provider_type to ProviderType (may raise for bad input)
        preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type)

        provider_configuration = self._get_provider_configuration(tenant_id, provider)

        # Switch preferred provider type
        provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)

    def free_quota_submit(self, tenant_id: str, provider: str):
        """
        Apply for free quota via the external apply service.

        :param tenant_id: workspace id
        :param provider: provider name
        :return: dict with 'type' and either 'redirect_url' or 'result'
        :raises ValueError: on HTTP failure or a non-success response code
        """
        # NOTE(review): assumes both env vars are set in deployments using this
        # feature; a missing base URL raises TypeError on concatenation — confirm.
        api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
        api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
        api_url = api_base_url + '/api/v1/providers/apply'

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f"Bearer {api_key}"
        }
        response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider})
        if not response.ok:
            logger.error("Request FREE QUOTA APPLY SERVER Error: %s", response.status_code)
            raise ValueError(f"Error: {response.status_code} ")

        # Parse the body once instead of re-reading it for every field.
        rst = response.json()
        if rst["code"] != 'success':
            raise ValueError(
                f"error: {rst['message']}"
            )

        if rst['type'] == 'redirect':
            return {
                'type': rst['type'],
                'redirect_url': rst['redirect_url']
            }
        else:
            return {
                'type': rst['type'],
                'result': 'success'
            }

    def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
        """
        Verify free-quota qualification via the external apply service.

        :param tenant_id: workspace id
        :param provider: provider name
        :param token: optional verification token
        :return: dict with 'result', 'provider_name', 'flag' and optionally 'reason'
        :raises ValueError: on HTTP failure or a non-success response code
        """
        # NOTE(review): same env-var assumption as free_quota_submit — confirm.
        api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
        api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
        api_url = api_base_url + '/api/v1/providers/qualification-verify'

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f"Bearer {api_key}"
        }
        json_data = {'workspace_id': tenant_id, 'provider_name': provider}
        if token:
            json_data['token'] = token
        response = requests.post(api_url, headers=headers,
                                 json=json_data)
        if not response.ok:
            logger.error("Request FREE QUOTA APPLY SERVER Error: %s", response.status_code)
            raise ValueError(f"Error: {response.status_code} ")

        rst = response.json()
        if rst["code"] != 'success':
            raise ValueError(
                f"error: {rst['message']}"
            )

        data = rst['data']
        if data['qualified'] is True:
            return {
                'result': 'success',
                'provider_name': provider,
                'flag': True
            }
        else:
            return {
                'result': 'success',
                'provider_name': provider,
                'flag': False,
                'reason': data['reason']
            }

View File

@ -1,596 +0,0 @@
import datetime
import json
import logging
import os
from collections import defaultdict
from typing import Optional
import requests
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
TenantDefaultModel
class ProviderService:
    def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> dict:
        """
        get provider list of tenant.

        :param tenant_id: workspace id
        :param model_type: filter by model type
        :return: dict keyed by provider name; each value holds the preferred
                 provider type, model flexibility, supported model types and
                 a 'providers' list of per-type parameter dicts
        """
        # get rules for all providers
        model_provider_rules = ModelProviderFactory.get_provider_rules()
        model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()]

        # NOTE(review): presumably ensures trial-quota provider records exist
        # before listing — confirm against get_preferred_model_provider.
        for model_provider_name, model_provider_rule in model_provider_rules.items():
            if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \
                    and 'system_config' in model_provider_rule and model_provider_rule['system_config'] \
                    and 'supported_quota_types' in model_provider_rule['system_config'] \
                    and 'trial' in model_provider_rule['system_config']['supported_quota_types']:
                ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        # providers supporting per-model ("configurable") custom credentials
        # (the comprehension variable shadows the outer model_provider_rules dict)
        configurable_model_provider_names = [
            model_provider_name
            for model_provider_name, model_provider_rules in model_provider_rules.items()
            if 'custom' in model_provider_rules['support_provider_types']
               and model_provider_rules['model_flexibility'] == 'configurable'
        ]

        # get all providers for the tenant
        providers = db.session.query(Provider) \
            .filter(
            Provider.tenant_id == tenant_id,
            Provider.provider_name.in_(model_provider_names),
            Provider.is_valid == True
        ).order_by(Provider.created_at.desc()).all()

        provider_name_to_provider_dict = defaultdict(list)
        for provider in providers:
            provider_name_to_provider_dict[provider.provider_name].append(provider)

        # get all configurable provider models for the tenant
        provider_models = db.session.query(ProviderModel) \
            .filter(
            ProviderModel.tenant_id == tenant_id,
            ProviderModel.provider_name.in_(configurable_model_provider_names),
            ProviderModel.is_valid == True
        ).order_by(ProviderModel.created_at.desc()).all()

        provider_name_to_provider_model_dict = defaultdict(list)
        for provider_model in provider_models:
            provider_name_to_provider_model_dict[provider_model.provider_name].append(provider_model)

        # get all preferred provider type for the tenant
        preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
            .filter(
            TenantPreferredModelProvider.tenant_id == tenant_id,
            TenantPreferredModelProvider.provider_name.in_(model_provider_names)
        ).all()

        provider_name_to_preferred_provider_type_dict = {preferred_provider_type.provider_name: preferred_provider_type
                                                         for preferred_provider_type in preferred_provider_types}

        providers_list = {}

        for model_provider_name, model_provider_rule in model_provider_rules.items():
            # apply the optional model-type filter
            if model_type and model_type not in model_provider_rule.get('supported_model_types', []):
                continue

            # get preferred provider type
            preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
            preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
                tenant_id,
                model_provider_name,
                preferred_model_provider
            )

            provider_config_dict = {
                "preferred_provider_type": preferred_provider_type,
                "model_flexibility": model_provider_rule['model_flexibility'],
                "supported_model_types": model_provider_rule.get("supported_model_types", []),
            }

            # Build placeholder entries first ("need update"), then overlay
            # actual tenant records below.
            provider_parameter_dict = {}
            if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types']:
                # one entry per supported system quota type, keyed "system:<quota_type>"
                for quota_type_enum in ProviderQuotaType:
                    quota_type = quota_type_enum.value
                    if quota_type in model_provider_rule['system_config']['supported_quota_types']:
                        key = ProviderType.SYSTEM.value + ':' + quota_type
                        provider_parameter_dict[key] = {
                            "provider_name": model_provider_name,
                            "provider_type": ProviderType.SYSTEM.value,
                            "config": None,
                            "is_valid": False,  # need update
                            "quota_type": quota_type,
                            "quota_unit": model_provider_rule['system_config']['quota_unit'],  # need update
                            "quota_limit": 0 if quota_type != ProviderQuotaType.TRIAL.value else
                            model_provider_rule['system_config']['quota_limit'],  # need update
                            "quota_used": 0,  # need update
                            "last_used": None  # need update
                        }

            if ProviderType.CUSTOM.value in model_provider_rule['support_provider_types']:
                provider_parameter_dict[ProviderType.CUSTOM.value] = {
                    "provider_name": model_provider_name,
                    "provider_type": ProviderType.CUSTOM.value,
                    "config": None,  # need update
                    "models": [],  # need update
                    "is_valid": False,
                    "last_used": None  # need update
                }

            model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)

            # overlay actual tenant provider records onto the placeholders
            current_providers = provider_name_to_provider_dict[model_provider_name]
            for provider in current_providers:
                if provider.provider_type == ProviderType.SYSTEM.value:
                    quota_type = provider.quota_type
                    key = f'{ProviderType.SYSTEM.value}:{quota_type}'

                    if key in provider_parameter_dict:
                        provider_parameter_dict[key]['is_valid'] = provider.is_valid
                        provider_parameter_dict[key]['quota_used'] = provider.quota_used
                        provider_parameter_dict[key]['quota_limit'] = provider.quota_limit
                        provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \
                            if provider.last_used else None
                elif provider.provider_type == ProviderType.CUSTOM.value \
                        and ProviderType.CUSTOM.value in provider_parameter_dict:
                    # if custom
                    key = ProviderType.CUSTOM.value
                    provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \
                        if provider.last_used else None
                    provider_parameter_dict[key]['is_valid'] = provider.is_valid

                    if model_provider_rule['model_flexibility'] == 'fixed':
                        # fixed providers expose one obfuscated provider-level config
                        provider_parameter_dict[key]['config'] = model_provider_class(provider=provider) \
                            .get_provider_credentials(obfuscated=True)
                    else:
                        # configurable providers expose per-model obfuscated configs
                        models = []
                        provider_models = provider_name_to_provider_model_dict[model_provider_name]
                        for provider_model in provider_models:
                            models.append({
                                "model_name": provider_model.model_name,
                                "model_type": provider_model.model_type,
                                "config": model_provider_class(provider=provider) \
                                    .get_model_credentials(provider_model.model_name,
                                                           ModelType.value_of(provider_model.model_type),
                                                           obfuscated=True),
                                "is_valid": provider_model.is_valid
                            })

                        provider_parameter_dict[key]['models'] = models

            provider_config_dict['providers'] = list(provider_parameter_dict.values())

            providers_list[model_provider_name] = provider_config_dict

        return providers_list
def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
    """
    Validate credentials for a custom (fixed-model) provider.

    :param provider_name: name of the model provider to validate against
    :param config: provider credential payload
    :return: None
    :raises ValueError: if the provider is not fixed-model or does not support CUSTOM
    :raises CredentialsValidateFailedError: When the config credential verification fails.
    """
    rules = ModelProviderFactory.get_provider_rule(provider_name)

    # this entry point only handles providers with a fixed model catalogue
    if rules['model_flexibility'] != 'fixed':
        raise ValueError('Only support fixed model provider')

    # the provider must accept custom (user-supplied) credentials
    if ProviderType.CUSTOM.value not in rules['support_provider_types']:
        raise ValueError('Only support provider type CUSTOM')

    # delegate the actual credential check to the provider implementation
    provider_cls = ModelProviderFactory.get_model_provider_class(provider_name)
    provider_cls.is_provider_credentials_valid_or_raise(config)
def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
    """
    Validate and persist custom provider credentials for a tenant.

    Credentials are encrypted before storage; an existing CUSTOM provider
    row is updated in place, otherwise a new one is inserted.

    :param tenant_id: tenant owning the credentials
    :param provider_name: provider the credentials belong to
    :param config: plaintext credential payload
    :return: None
    """
    # reject bad credentials before touching the database
    self.custom_provider_config_validate(provider_name, config)

    existing = db.session.query(Provider).filter(
        Provider.tenant_id == tenant_id,
        Provider.provider_name == provider_name,
        Provider.provider_type == ProviderType.CUSTOM.value
    ).first()

    provider_cls = ModelProviderFactory.get_model_provider_class(provider_name)
    credentials_json = json.dumps(provider_cls.encrypt_provider_credentials(tenant_id, config))

    if existing:
        existing.encrypted_config = credentials_json
        existing.is_valid = True
        existing.updated_at = datetime.datetime.utcnow()
    else:
        db.session.add(Provider(
            tenant_id=tenant_id,
            provider_name=provider_name,
            provider_type=ProviderType.CUSTOM.value,
            encrypted_config=credentials_json,
            is_valid=True
        ))

    db.session.commit()
def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
    """
    Remove a tenant's CUSTOM provider record, if present.

    Before deleting, the preferred provider is switched back to SYSTEM on a
    best-effort basis: a failing switch does not abort the deletion.

    :param tenant_id: tenant owning the provider record
    :param provider_name: provider to delete
    :return: None
    """
    record = db.session.query(Provider).filter(
        Provider.tenant_id == tenant_id,
        Provider.provider_name == provider_name,
        Provider.provider_type == ProviderType.CUSTOM.value
    ).first()

    if not record:
        return

    try:
        # fall back to the hosted (SYSTEM) provider when possible
        self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
    except ValueError:
        # SYSTEM may be unsupported for this provider; deletion proceeds anyway
        pass

    db.session.delete(record)
    db.session.commit()
def custom_provider_model_config_validate(self,
                                          provider_name: str,
                                          model_name: str,
                                          model_type: str,
                                          config: dict) -> None:
    """
    Validate credentials for a single model of a configurable provider.

    :param provider_name: provider the model belongs to
    :param model_name: model to validate credentials for
    :param model_type: model type string (converted via ModelType.value_of)
    :param config: plaintext credential payload
    :return: None
    :raises ValueError: if the provider is not configurable or lacks CUSTOM support
    :raises CredentialsValidateFailedError: When the config credential verification fails.
    """
    rules = ModelProviderFactory.get_provider_rule(provider_name)

    # per-model credentials only make sense for configurable providers
    if rules['model_flexibility'] != 'configurable':
        raise ValueError('Only support configurable model provider')

    if ProviderType.CUSTOM.value not in rules['support_provider_types']:
        raise ValueError('Only support provider type CUSTOM')

    # delegate the actual credential check to the provider implementation
    resolved_type = ModelType.value_of(model_type)
    provider_cls = ModelProviderFactory.get_model_provider_class(provider_name)
    provider_cls.is_model_credentials_valid_or_raise(model_name, resolved_type, config)
def add_or_save_custom_provider_model_config(self,
                                             tenant_id: str,
                                             provider_name: str,
                                             model_name: str,
                                             model_type: str,
                                             config: dict) -> None:
    """
    Validate and persist per-model credentials for a configurable provider.

    Ensures a CUSTOM provider row exists (and is marked valid), encrypts the
    model credentials, then inserts or updates the matching ProviderModel row.

    :param tenant_id: tenant owning the credentials
    :param provider_name: provider the model belongs to
    :param model_name: model to store credentials for
    :param model_type: model type string
    :param config: plaintext credential payload
    :return: None
    """
    # reject bad credentials before touching the database
    self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)

    provider = db.session.query(Provider).filter(
        Provider.tenant_id == tenant_id,
        Provider.provider_name == provider_name,
        Provider.provider_type == ProviderType.CUSTOM.value
    ).first()

    if provider is None:
        # first model for this provider: create the parent provider row
        provider = Provider(
            tenant_id=tenant_id,
            provider_name=provider_name,
            provider_type=ProviderType.CUSTOM.value,
            is_valid=True
        )
        db.session.add(provider)
        db.session.commit()
    elif not provider.is_valid:
        # revive an invalidated provider; its stale config is discarded
        provider.is_valid = True
        provider.encrypted_config = None
        db.session.commit()

    provider_cls = ModelProviderFactory.get_model_provider_class(provider_name)
    credentials_json = json.dumps(provider_cls.encrypt_model_credentials(
        tenant_id,
        model_name,
        ModelType.value_of(model_type),
        config
    ))

    model_record = db.session.query(ProviderModel).filter(
        ProviderModel.tenant_id == tenant_id,
        ProviderModel.provider_name == provider_name,
        ProviderModel.model_name == model_name,
        ProviderModel.model_type == model_type
    ).first()

    if model_record:
        model_record.encrypted_config = credentials_json
        model_record.is_valid = True
    else:
        db.session.add(ProviderModel(
            tenant_id=tenant_id,
            provider_name=provider_name,
            model_name=model_name,
            model_type=model_type,
            encrypted_config=credentials_json,
            is_valid=True
        ))

    db.session.commit()
def delete_custom_provider_model(self,
                                 tenant_id: str,
                                 provider_name: str,
                                 model_name: str,
                                 model_type: str) -> None:
    """
    Delete the stored credentials row for a single provider model, if any.

    :param tenant_id: tenant owning the record
    :param provider_name: provider the model belongs to
    :param model_name: model whose credentials are removed
    :param model_type: model type string
    :return: None
    """
    target = db.session.query(ProviderModel).filter(
        ProviderModel.tenant_id == tenant_id,
        ProviderModel.provider_name == provider_name,
        ProviderModel.model_name == model_name,
        ProviderModel.model_type == model_type
    ).first()

    # silently a no-op when no matching row exists
    if target:
        db.session.delete(target)
        db.session.commit()
def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
    """
    Switch the tenant's preferred provider type (e.g. SYSTEM/CUSTOM) for a provider.

    Silently does nothing when the provider does not support the SYSTEM
    provider type at all (there is nothing to prefer between).

    :param tenant_id: tenant whose preference is changed
    :param provider_name: provider the preference applies to
    :param preferred_provider_type: provider type string to prefer
    :return: None
    :raises ValueError: for an unknown or unsupported preferred provider type
    """
    resolved = ProviderType.value_of(preferred_provider_type)
    if not resolved:
        raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')

    rules = ModelProviderFactory.get_provider_rule(provider_name)
    if preferred_provider_type not in rules['support_provider_types']:
        raise ValueError(f'Not support provider type: {preferred_provider_type}')

    provider_cls = ModelProviderFactory.get_model_provider_class(provider_name)
    if not provider_cls.is_provider_type_system_supported():
        # only CUSTOM exists for this provider — preference is meaningless
        return

    preference = db.session.query(TenantPreferredModelProvider).filter(
        TenantPreferredModelProvider.tenant_id == tenant_id,
        TenantPreferredModelProvider.provider_name == provider_name
    ).first()

    if preference:
        preference.preferred_provider_type = preferred_provider_type
    else:
        db.session.add(TenantPreferredModelProvider(
            tenant_id=tenant_id,
            provider_name=provider_name,
            preferred_provider_type=preferred_provider_type
        ))

    db.session.commit()
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]:
    """
    Fetch the tenant's default model for the given model type.

    :param tenant_id: tenant to look up
    :param model_type: model type string (converted via ModelType.value_of)
    :return: the default model record, or None when none is configured
    """
    resolved_type = ModelType.value_of(model_type)
    return ModelFactory.get_default_model(tenant_id, resolved_type)
def update_default_model_of_model_type(self,
                                       tenant_id: str,
                                       model_type: str,
                                       provider_name: str,
                                       model_name: str) -> TenantDefaultModel:
    """
    Set the tenant's default model for a model type and return the record.

    :param tenant_id: tenant to update
    :param model_type: model type string (converted via ModelType.value_of)
    :param provider_name: provider of the new default model
    :param model_name: name of the new default model
    :return: the updated TenantDefaultModel record
    """
    resolved_type = ModelType.value_of(model_type)
    return ModelFactory.update_default_model(tenant_id, resolved_type, provider_name, model_name)
def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
    """
    get valid model list.

    Walks every provider rule, resolves the tenant's preferred provider for
    each provider name, and collects all models of the requested type into
    plain dicts (name, display name, provider info, optional mode/features,
    and quota details for SYSTEM providers).

    :param tenant_id: tenant to resolve preferred providers for
    :param model_type: model type string (converted via ModelType.value_of)
    :return: list of dicts, one per available model
    """
    valid_model_list = []

    # get model provider rules
    model_provider_rules = ModelProviderFactory.get_provider_rules()
    for model_provider_name, model_provider_rule in model_provider_rules.items():
        # resolve the provider instance the tenant prefers for this name;
        # a falsy result means no usable provider, so skip it entirely
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
        if not model_provider:
            continue

        model_list = model_provider.get_supported_model_list(ModelType.value_of(model_type))
        provider = model_provider.provider
        for model in model_list:
            valid_model_dict = {
                "model_name": model['id'],
                "model_display_name": model['name'],
                "model_type": model_type,
                "model_provider": {
                    "provider_name": provider.provider_name,
                    "provider_type": provider.provider_type
                },
                'features': []
            }

            # optional per-model attributes supplied by the provider
            if 'mode' in model:
                valid_model_dict['model_mode'] = model['mode']

            if 'features' in model:
                valid_model_dict['features'] = model['features']

            # hosted (SYSTEM) providers additionally expose quota information
            if provider.provider_type == ProviderType.SYSTEM.value:
                valid_model_dict['model_provider']['quota_type'] = provider.quota_type
                valid_model_dict['model_provider']['quota_unit'] = model_provider_rule['system_config']['quota_unit']
                valid_model_dict['model_provider']['quota_limit'] = provider.quota_limit
                valid_model_dict['model_provider']['quota_used'] = provider.quota_used

            valid_model_list.append(valid_model_dict)

    return valid_model_list
def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
        -> ModelKwargsRules:
    """
    Get the kwargs rules for a model, resolved via the tenant's preferred provider.

    It depends on preferred provider in use.

    :param tenant_id: tenant to resolve the preferred provider for
    :param model_provider_name: provider the model belongs to
    :param model_name: model to fetch rules for
    :param model_type: model type string (converted via ModelType.value_of)
    :return: the provider's rules, or an empty ModelKwargsRules when no
             preferred provider is available
    """
    provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
    if not provider:
        # no usable provider: fall back to an empty rule set
        return ModelKwargsRules()

    return provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))
def free_quota_submit(self, tenant_id: str, provider_name: str):
    """
    Submit a free-quota application for a provider to the remote apply service.

    :param tenant_id: workspace (tenant) id applying for the quota
    :param provider_name: name of the model provider
    :return: dict with 'type' and either 'redirect_url' (redirect flow) or
             'result': 'success'
    :raises ValueError: when the remote service responds with an HTTP error
                        or a non-success application code
    """
    api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
    api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
    api_url = api_base_url + '/api/v1/providers/apply'

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {api_key}"
    }
    response = requests.post(api_url, headers=headers,
                             json={'workspace_id': tenant_id, 'provider_name': provider_name})

    if not response.ok:
        logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
        raise ValueError(f"Error: {response.status_code} ")

    # Parse the payload exactly once instead of calling response.json()
    # repeatedly (also keeps this consistent with free_quota_qualification_verify).
    rst = response.json()
    if rst["code"] != 'success':
        raise ValueError(
            f"error: {rst['message']}"
        )

    if rst['type'] == 'redirect':
        return {
            'type': rst['type'],
            'redirect_url': rst['redirect_url']
        }
    else:
        return {
            'type': rst['type'],
            'result': 'success'
        }
def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
    """
    Ask the remote apply service whether this workspace qualifies for free quota.

    :param tenant_id: workspace (tenant) id to verify
    :param provider_name: name of the model provider
    :param token: optional verification token forwarded to the service
    :return: dict with 'result', 'provider_name', 'flag', and — only when
             unqualified — a 'reason' string
    :raises ValueError: when the remote service responds with an HTTP error
                        or a non-success code
    """
    api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
    api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
    api_url = api_base_url + '/api/v1/providers/qualification-verify'

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {api_key}"
    }

    payload = {'workspace_id': tenant_id, 'provider_name': provider_name}
    if token:
        payload['token'] = token

    response = requests.post(api_url, headers=headers,
                             json=payload)

    if not response.ok:
        logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
        raise ValueError(f"Error: {response.status_code} ")

    rst = response.json()
    if rst["code"] != 'success':
        raise ValueError(
            f"error: {rst['message']}"
        )

    data = rst['data']
    if data['qualified'] is True:
        return {
            'result': 'success',
            'provider_name': provider_name,
            'flag': True
        }

    return {
        'result': 'success',
        'provider_name': provider_name,
        'flag': False,
        'reason': data['reason']
    }

View File

@ -1,9 +1,11 @@
from typing import Optional
from flask import current_app, Flask
from langchain.embeddings.base import Embeddings
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db
from models.dataset import Dataset
@ -50,12 +52,24 @@ class RetrievalService:
if documents:
if reranking_model and search_method == 'semantic_search':
rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=reranking_model['reranking_provider_name'],
model_name=reranking_model['reranking_model_name']
)
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
)
except InvokeAuthorizationError:
return
rerank_runner = RerankRunner(rerank_model_instance)
all_documents.extend(rerank_runner.run(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents)
))
else:
all_documents.extend(documents)
@ -81,15 +95,23 @@ class RetrievalService:
)
if documents:
if reranking_model and search_method == 'full_text_search':
rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=reranking_model['reranking_provider_name'],
model_name=reranking_model['reranking_model_name']
)
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
)
except InvokeAuthorizationError:
return
rerank_runner = RerankRunner(rerank_model_instance)
all_documents.extend(rerank_runner.run(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents)
))
else:
all_documents.extend(documents)

View File

@ -19,7 +19,6 @@ class WorkspaceService:
'plan': tenant.plan,
'status': tenant.status,
'created_at': tenant.created_at,
'providers': [],
'in_trail': True,
'trial_end_reason': None,
'role': 'normal',
@ -37,12 +36,4 @@ class WorkspaceService:
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
tenant_info['custom_config'] = tenant.custom_config_dict
# Get providers
providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id
).all()
# Add providers to the tenant info
tenant_info['providers'] = providers
return tenant_info