Compare commits

...

25 Commits

Author SHA1 Message Date
5809edd74b feat: bump version to 0.3.23 (#1198) 2023-09-20 00:14:36 +08:00
05bfa11915 build: update devDependencies (#1125) 2023-09-19 13:31:48 +08:00
435f804c6f fix: gpt-3.5-turbo-instruct context size to 8192 (#1196) 2023-09-19 02:10:22 +08:00
ae3f1ac0a9 feat: support gpt-3.5-turbo-instruct model (#1195) 2023-09-19 02:05:04 +08:00
269a465fc4 Feat/improve vector database logic (#1193)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-18 18:15:41 +08:00
60e0bbd713 Feat/provider add zhipuai (#1192)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-09-18 18:02:05 +08:00
827c97f0d3 feat: add zhipuai (#1188) 2023-09-18 17:32:31 +08:00
c8bd76cd66 fix: inference embedding validate (#1187) 2023-09-16 03:09:36 +08:00
ec5f585df4 1111 wrong embedding model displayed in datasets (#1186) 2023-09-15 07:54:45 -05:00
1de48f33ca feat(web): service request return generics type (#1157) 2023-09-15 07:54:20 -05:00
6b41a9593e fix: text error (#1184) 2023-09-15 14:15:28 +08:00
82267083e8 fix: model param description error (#1183) 2023-09-15 11:36:01 +08:00
c385961d33 chore: Optimization model parameter description (#1181) 2023-09-15 11:14:14 +08:00
20bab6edec Restore the application template (#1174)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-09-14 08:28:32 -05:00
67bed54f32 Mermaid front end rendering (#1166)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-09-14 14:09:23 +08:00
leo
562a571281 fix: Improved fallback solution for avatar image loading failure (#1172) 2023-09-14 13:31:35 +08:00
fc68c81791 fix: correct invite url (#1173) 2023-09-14 12:07:34 +08:00
5d9070bc60 Feat/add blocking mode resource return (#1171)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-13 18:53:35 +08:00
b11fb0dfd1 fix LocalAI is missing in lang/en (#1169) 2023-09-13 10:08:33 +08:00
d1c5c5f160 add video to cn readme (#1165) 2023-09-12 08:30:12 -05:00
0b1d1440aa Update README.md (#1164) 2023-09-12 07:48:35 -05:00
0c420d64b3 chore: hover conversation show option button (#1160) 2023-09-12 16:35:13 +08:00
f9082104ed feat: add hosted moderation (#1158) 2023-09-12 10:26:12 +08:00
983834cd52 feat: spark check (#1134) 2023-09-11 17:31:03 +08:00
96d10c8b39 feat: spark free quota verify (#1152) 2023-09-11 17:30:54 +08:00
118 changed files with 3549 additions and 1173 deletions

View File

@ -16,6 +16,10 @@ Out-of-the-box web sites supporting form mode and chat conversation mode
A single API encompassing plugin capabilities, context enhancement, and more, saving you backend coding effort
Visual data analysis, log review, and annotation for applications
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
## Highlighted Features
**1. LLMs support:** Choose capabilities based on different models when building your Dify AI apps. Dify is compatible with Langchain, meaning it will support various LLMs. Currently supported:

View File

@ -17,7 +17,7 @@
- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
- 可视化的对应用进行数据分析,查阅日志或进行标注
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
## 核心能力
1. **模型支持:** 你可以在 Dify 上选择基于不同模型的能力来开发你的 AI 应用。Dify 兼容 Langchain这意味着我们将逐步支持多种 LLMs ,目前支持的模型供应商:

View File

@ -4,6 +4,7 @@ import math
import random
import string
import time
import uuid
import click
from tqdm import tqdm
@ -23,7 +24,7 @@ from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
import secrets
import base64
@ -239,7 +240,13 @@ def clean_unused_dataset_indexes():
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete()
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
vector_index.delete()
kw_index.delete()
# update document
update_params = {
@ -346,7 +353,8 @@ def create_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@ -364,7 +372,8 @@ def create_qdrant_indexes():
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
"vector_store": {
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
@ -373,7 +382,8 @@ def create_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@ -414,7 +424,8 @@ def update_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@ -435,11 +446,104 @@ def update_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
click.echo(click.style('Start normalization collections.', fg='green'))
normalization_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
if not dataset.collection_binding_id:
try:
click.echo('restore dataset index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
DatasetCollectionBinding.model_name == embedding_model.name). \
order_by(DatasetCollectionBinding.created_at). \
first()
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=embedding_model.model_provider.provider_name,
model_name=embedding_model.name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
)
db.session.add(dataset_collection_binding)
db.session.commit()
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.restore_dataset_in_one(dataset, dataset_collection_binding)
else:
click.echo('passed.')
original_index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if original_index:
original_index.delete_original_collection(dataset, dataset_collection_binding)
normalization_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
@ -473,7 +577,7 @@ def update_app_model_configs(batch_size):
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.count()
if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
@ -485,14 +589,14 @@ def update_app_model_configs(batch_size):
offset = i * batch_size
limit = min(batch_size, total_records - offset)
click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.order_by(App.created_at) \
.offset(offset).limit(limit).all()
if not data_batch:
click.secho("No more data to migrate.", fg='green')
break
@ -512,7 +616,7 @@ def update_app_model_configs(batch_size):
app_data = db.session.query(App) \
.filter(App.id == data.app_id) \
.one()
account_data = db.session.query(Account) \
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
.filter(TenantAccountJoin.role == 'owner') \
@ -534,13 +638,15 @@ def update_app_model_configs(batch_size):
db.session.commit()
except Exception as e:
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue
click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@ -551,4 +657,5 @@ def register_commands(app):
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)

View File

@ -61,6 +61,8 @@ DEFAULTS = {
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
@ -100,7 +102,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.22"
self.CURRENT_VERSION = "0.3.23"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@ -230,6 +232,9 @@ class Config:
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

View File

@ -16,7 +16,7 @@ model_templates = {
},
'model_config': {
'provider': 'openai',
'model_id': 'text-davinci-003',
'model_id': 'gpt-3.5-turbo-instruct',
'configs': {
'prompt_template': '',
'prompt_variables': [],
@ -30,7 +30,7 @@ model_templates = {
},
'model': json.dumps({
"provider": "openai",
"name": "text-davinci-003",
"name": "gpt-3.5-turbo-instruct",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@ -104,7 +104,7 @@ demo_model_templates = {
'mode': 'completion',
'model_config': AppModelConfig(
provider='openai',
model_id='text-davinci-003',
model_id='gpt-3.5-turbo-instruct',
configs={
'prompt_template': "Please translate the following text into {{target_language}}:\n",
'prompt_variables': [
@ -140,7 +140,7 @@ demo_model_templates = {
pre_prompt="Please translate the following text into {{target_language}}:\n",
model=json.dumps({
"provider": "openai",
"name": "text-davinci-003",
"name": "gpt-3.5-turbo-instruct",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@ -222,7 +222,7 @@ demo_model_templates = {
'mode': 'completion',
'model_config': AppModelConfig(
provider='openai',
model_id='text-davinci-003',
model_id='gpt-3.5-turbo-instruct',
configs={
'prompt_template': "请将以下文本翻译为{{target_language}}:\n",
'prompt_variables': [
@ -258,7 +258,7 @@ demo_model_templates = {
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
model=json.dumps({
"provider": "openai",
"name": "text-davinci-003",
"name": "gpt-3.5-turbo-instruct",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,

View File

@ -33,7 +33,6 @@ class UniversalChatApi(UniversalChatResource):
args = parser.parse_args()
app_model_config = app_model.app_model_config
app_model_config
# update app model config
args['model_config'] = app_model_config.to_dict()

View File

@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
'enabled': v.enabled,
'min': v.min,
'max': v.max,
'default': v.default
'default': v.default,
'precision': v.precision
}
for k, v in vars(parameter_rules).items()
}
@ -285,6 +286,25 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
return result
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
args = parser.parse_args()
provider_service = ProviderService()
result = provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
token=args['token']
)
return result
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
@ -300,3 +320,5 @@ api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')

View File

@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@ -116,6 +118,18 @@ class AgentExecutor:
return self.agent.should_use_agent(query)
def run(self, query: str) -> AgentExecuteResult:
moderation_result = moderation.check_moderation(
self.configuration.model_instance.model_provider,
query
)
if not moderation_result:
return AgentExecuteResult(
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
strategy=self.configuration.strategy,
configuration=self.configuration
)
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
@ -128,7 +142,9 @@ class AgentExecutor:
try:
output = agent_executor.run(query)
except Exception:
except LLMError as ex:
raise ex
except Exception as ex:
logging.exception("agent_executor run failed")
output = None

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.model_instant = model_instant
self.model_instance = model_instance
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Whether to ignore chain callbacks."""
return True
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
if not self._current_loop:
# Agent start with a LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt="\n".join([message.content for message in messages[0]]),
status='llm_started',
started_at=time.perf_counter()
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0]
@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instant, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)
@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instant, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)

View File

@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
prompt_tokens: int = 0
completion: str = ''
completion_tokens: int = 0
latency: float = 0.0

View File

@ -1,5 +1,4 @@
import logging
import time
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
self.start_at = time.perf_counter()
real_prompts = []
for message in messages[0]:
if message.type == 'human':
@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.start_at = time.perf_counter()
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
@ -63,14 +59,22 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
if 'completion_tokens' in response.llm_output['token_usage']:
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)
@ -89,8 +93,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)

View File

@ -1,15 +1,33 @@
import enum
import logging
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import BaseModel
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation import openai_moderation
class SensitiveWordAvoidanceRule(BaseModel):
class Type(enum.Enum):
MODERATION = "moderation"
KEYWORDS = "keywords"
type: Type
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
extra_params: dict = {}
class SensitiveWordAvoidanceChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
sensitive_words: List[str] = []
canned_response: str = None
model_instance: BaseLLM
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
@property
def _chain_type(self) -> str:
@ -31,11 +49,24 @@ class SensitiveWordAvoidanceChain(Chain):
"""
return [self.output_key]
def _check_sensitive_word(self, text: str) -> str:
for word in self.sensitive_words:
def _check_sensitive_word(self, text: str) -> bool:
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
if word in text:
return self.canned_response
return text
return False
return True
def _check_moderation(self, text: str) -> bool:
moderation_model_instance = ModelFactory.get_moderation_model(
tenant_id=self.model_instance.model_provider.provider.tenant_id,
model_provider_name='openai',
model_name=openai_moderation.DEFAULT_MODEL
)
try:
return moderation_model_instance.run(text=text)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
def _call(
self,
@ -43,5 +74,19 @@ class SensitiveWordAvoidanceChain(Chain):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
result = self._check_sensitive_word(text)
else:
result = self._check_moderation(text)
if not result:
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
return {self.output_key: text}
class SensitiveWordAvoidanceError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message

View File

@ -1,24 +1,22 @@
import json
import logging
import re
from typing import Optional, List, Union, Tuple
from typing import Optional, List, Union
from langchain.schema import BaseMessage
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
@ -79,28 +77,53 @@ class Completion:
app_model_config=app_model_config
)
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
if sensitive_word_avoidance_chain:
query = sensitive_word_avoidance_chain.run(query)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback
)
# run agent executor
agent_execute_result = None
if agent_executor:
should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent:
agent_execute_result = agent_executor.run(query)
# run the final llm
try:
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
try:
query = sensitive_word_avoidance_chain.run(query)
except SensitiveWordAvoidanceError as ex:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=ex.message
)
return
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback,
retriever_from=retriever_from
)
# run agent executor
agent_execute_result = None
if agent_executor:
should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent:
agent_execute_result = agent_executor.run(query)
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
# run the final llm
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
@ -109,7 +132,8 @@ class Completion:
inputs=inputs,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory
memory=memory,
fake_response=fake_response
)
except ConversationTaskStoppedException:
return
@ -124,14 +148,8 @@ class Completion:
inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,

View File

@ -1,5 +1,5 @@
import decimal
import json
import time
from typing import Optional, Union, List
from core.callback_handler.entity.agent_loop import AgentLoop
@ -23,6 +23,8 @@ class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
self.start_at = time.perf_counter()
self.task_id = task_id
self.app = app
@ -61,6 +63,7 @@ class ConversationMessageTask:
)
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = self.app_model_config.to_dict()
@ -165,7 +168,7 @@ class ConversationMessageTask:
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit
self.message.provider_response_latency = llm_message.latency
self.message.provider_response_latency = time.perf_counter() - self.start_at
self.message.total_price = total_price
db.session.commit()
@ -220,18 +223,18 @@ class ConversationMessageTask:
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
message_agent_thought.observation = agent_loop.tool_output
@ -245,7 +248,7 @@ class ConversationMessageTask:
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = agent_model_instant.get_currency()
message_agent_thought.currency = agent_model_instance.get_currency()
db.session.flush()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):

View File

@ -0,0 +1,34 @@
import logging
import openai
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
from models.provider import ProviderType
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and model_provider.provider_name in hosted_config.moderation.providers:
# 2000 text per chunk
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
for text_chunk in chunks:
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True

View File

@ -16,6 +16,10 @@ class BaseIndex(ABC):
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@ -28,6 +32,10 @@ class BaseIndex(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError

View File

@ -46,6 +46,32 @@ class KeywordTableIndex(BaseIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
db.session.delete(dataset_keyword_table)
db.session.commit()
def delete_by_group_id(self, group_id: str) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',

View File

@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument
@ -110,6 +110,12 @@ class BaseVectorIndex(BaseIndex):
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@ -243,3 +249,53 @@ class BaseVectorIndex(BaseIndex):
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"delete original collection: {dataset.id}")
self.delete()
dataset.collection_binding_id = dataset_collection_binding.id
db.session.add(dataset)
db.session.commit()
logging.info(f"Dataset {dataset.id} recreate successfully.")

View File

@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=collection_name,
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:

View File

@ -28,6 +28,7 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@ -84,6 +85,7 @@ class Qdrant(VectorStore):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR_NAME = None
def __init__(
@ -93,9 +95,12 @@ class Qdrant(VectorStore):
embeddings: Optional[Embeddings] = None,
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME,
embedding_function: Optional[Callable] = None, # deprecated
is_new_collection: bool = False
):
"""Initialize with necessary components."""
try:
@ -129,7 +134,10 @@ class Qdrant(VectorStore):
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
self.group_payload_key = group_payload_key or self.GROUP_KEY
self.vector_name = vector_name or self.VECTOR_NAME
self.group_id = group_id
self.is_new_collection= is_new_collection
if embedding_function is not None:
warnings.warn(
@ -170,6 +178,8 @@ class Qdrant(VectorStore):
batch_size:
How many vectors upload per-request.
Default: 64
group_id:
collection group
Returns:
List of ids from adding the texts into the vectorstore.
@ -182,7 +192,11 @@ class Qdrant(VectorStore):
collection_name=self.collection_name, points=points, **kwargs
)
added_ids.extend(batch_ids)
# if is new collection, create payload index on group_id
if self.is_new_collection:
self.client.create_payload_index(self.collection_name, self.group_payload_key,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
return added_ids
@sync_call_fallback
@ -970,6 +984,8 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
metadata_payload_key:
A payload key used to store the metadata of the document.
Default: "metadata"
group_payload_key:
A payload key used to store the content of the document.
Default: "group_id"
group_id:
collection group id
vector_name:
Name of the vector to be used internally in Qdrant.
Default: None
@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
distance_func,
content_payload_key,
metadata_payload_key,
group_payload_key,
group_id,
vector_name,
shard_number,
replication_factor,
@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
is_new_collection = False
client = qdrant_client.QdrantClient(
location=location,
url=url,
@ -1454,6 +1480,7 @@ class Qdrant(VectorStore):
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
is_new_collection = True
qdrant = cls(
client=client,
collection_name=collection_name,
@ -1462,6 +1489,9 @@ class Qdrant(VectorStore):
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
group_id=group_id,
group_payload_key=group_payload_key,
is_new_collection=is_new_collection
)
return qdrant
@ -1516,6 +1546,8 @@ class Qdrant(VectorStore):
metadatas: Optional[List[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> List[dict]:
payloads = []
for i, text in enumerate(texts):
@ -1529,6 +1561,7 @@ class Qdrant(VectorStore):
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
@ -1578,7 +1611,7 @@ class Qdrant(VectorStore):
else:
out.append(
rest.FieldCondition(
key=f"{self.metadata_payload_key}.{key}",
key=key,
match=rest.MatchValue(value=value),
)
)
@ -1654,6 +1687,7 @@ class Qdrant(VectorStore):
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
@ -1684,6 +1718,8 @@ class Qdrant(VectorStore):
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
self.group_id,
self.group_payload_key
),
)
]

View File

@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
return dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
return class_prefix
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='page_content'
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@ -114,9 +139,6 @@ class QdrantVectorIndex(BaseVectorIndex):
))
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@ -132,6 +154,22 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=group_id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']

View File

@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:

View File

@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
@ -180,7 +181,7 @@ class ModelFactory:
def get_moderation_model(cls,
tenant_id: str,
model_provider_name: str,
model_name: str) -> Optional[BaseProviderModel]:
model_name: str) -> Optional[BaseModeration]:
"""
get moderation model.

View File

@ -45,6 +45,9 @@ class ModelProviderFactory:
elif provider_name == 'wenxin':
from core.model_providers.providers.wenxin_provider import WenxinProvider
return WenxinProvider
elif provider_name == 'zhipuai':
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
return ZhipuAIProvider
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider

View File

@ -0,0 +1,22 @@
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
class ZhipuAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = ZhipuAIEmbeddings(
model=name,
**credentials,
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")

View File

@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
max: Optional[T] = None
default: Optional[T] = None
alias: Optional[str] = None
precision: Optional[int] = None
class ModelKwargsRules(BaseModel):

View File

@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
:param callbacks:
:return:
"""
moderation_result = moderation.check_moderation(
self.model_provider,
"\n".join([message.content for message in messages])
)
if not moderation_result:
kwargs['fake_response'] = "I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest."
if self.deduct_quota:
self.model_provider.check_quota_over_limit()

View File

@ -17,6 +17,7 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
from models.provider import ProviderType, ProviderQuotaType
COMPLETION_MODELS = [
'gpt-3.5-turbo-instruct', # 4,096 tokens
'text-davinci-003', # 4,097 tokens
]
@ -31,6 +32,7 @@ MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}

View File

@ -0,0 +1,61 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
class ZhipuAIModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ZhipuAIChatLLM(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"ZhipuAI: {str(ex)}")
@property
def support_streaming(self):
return True

View File

@ -0,0 +1,29 @@
from abc import abstractmethod
from typing import Any
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseModeration(BaseProviderModel):
name: str
type: ModelType = ModelType.MODERATION
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
def run(self, text: str) -> bool:
try:
return self._run(text)
except Exception as ex:
raise self.handle_exceptions(ex)
@abstractmethod
def _run(self, text: str) -> bool:
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError

View File

@ -4,29 +4,39 @@ import openai
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.providers.base import BaseModelProvider
DEFAULT_AUDIO_MODEL = 'whisper-1'
DEFAULT_MODEL = 'whisper-1'
class OpenAIModeration(BaseProviderModel):
type: ModelType = ModelType.MODERATION
class OpenAIModeration(BaseModeration):
def __init__(self, model_provider: BaseModelProvider, name: str):
super().__init__(model_provider, openai.Moderation)
super().__init__(model_provider, openai.Moderation, name)
def run(self, text):
def _run(self, text: str) -> bool:
credentials = self.model_provider.get_model_credentials(
model_name=DEFAULT_AUDIO_MODEL,
model_name=self.name,
model_type=self.type
)
try:
return self._client.create(input=text, api_key=credentials['openai_api_key'])
except Exception as ex:
raise self.handle_exceptions(ex)
# 2000 text per chunk
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
for text_chunk in chunks:
moderation_result = self._client.create(input=text_chunk,
api_key=credentials['openai_api_key'])
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):

View File

@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
)
@classmethod

View File

@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider):
model_credentials = self.get_model_credentials(model_name, model_type)
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
model_credentials['base_model_name'],
4097
), default=16),
), default=16, precision=0),
)
@classmethod

View File

@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider):
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
)
@classmethod

View File

@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel):
hosted_model_providers = HostedModelProviders()
class HostedModerationConfig(BaseModel):
enabled: bool = False
providers: list[str] = []
class HostedConfig(BaseModel):
moderation = HostedModerationConfig()
hosted_config = HostedConfig()
def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
@ -78,3 +90,9 @@ def init_app(app: Flask):
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
hosted_config.moderation = HostedModerationConfig(
enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
)

View File

@ -47,11 +47,11 @@ class HuggingfaceHubProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
)
@classmethod

View File

@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=0.7),
top_p=KwargRule[float](min=0, max=1, default=1),
max_tokens=KwargRule[int](min=10, max=4097, default=16),
temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
)
@classmethod

View File

@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider):
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.9),
top_p=KwargRule[float](min=0, max=1, default=0.95),
temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
)
@classmethod

View File

@ -40,6 +40,10 @@ class OpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
@ -128,16 +132,17 @@ class OpenAIProvider(BaseModelProvider):
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
)
@classmethod

View File

@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
)
@classmethod

View File

@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider):
min=float(value.get('minimum')) if value.get('minimum') is not None else None,
max=float(value.get('maximum')) if value.get('maximum') is not None else None,
default=float(value.get('default')) if value.get('default') is not None else None,
precision = 2
)
if key == 'temperature':
model_kwargs_rules.temperature = kwarg_rule
@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider):
min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
default=int(value.get('default')) if value.get('default') is not None else 500,
precision = 0
)
return model_kwargs_rules

View File

@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.5),
temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
top_p=KwargRule[float](enabled=False),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4096, default=2048),
max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
)
@classmethod

View File

@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider):
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](min=0, max=1, default=0.8),
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
)
@classmethod

View File

@ -63,8 +63,8 @@ class WenxinProvider(BaseModelProvider):
"""
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95),
top_p=KwargRule[float](min=0.01, max=1, default=0.8),
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),

View File

@ -2,6 +2,7 @@ import json
from typing import Type
import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@ -52,27 +53,27 @@ class XinferenceProvider(BaseModelProvider):
credentials = self.get_model_credentials(model_name, model_type)
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
elif credentials['model_format'] == "ggmlv3":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
else:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
'model_uid': credentials['model_uid'],
}
llm = XinferenceLLM(
**credential_kwargs
)
if model_type == ModelType.TEXT_GENERATION:
llm = XinferenceLLM(
**credential_kwargs
)
llm("ping")
llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = XinferenceEmbeddings(
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
:param credentials:
:return:
"""
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
if model_type == ModelType.TEXT_GENERATION:
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])

View File

@ -0,0 +1,176 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
from models.provider import ProviderType, ProviderQuotaType
class ZhipuAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'zhipuai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'chatglm_pro',
'name': 'chatglm_pro',
},
{
'id': 'chatglm_std',
'name': 'chatglm_std',
},
{
'id': 'chatglm_lite',
'name': 'chatglm_lite',
},
{
'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k',
}
]
elif model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'text_embedding',
'name': 'text_embedding',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = ZhipuAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = ZhipuAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key']
}
llm = ZhipuAIChatLLM(
temperature=0.01,
**credential_kwargs
)
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value \
or (self.provider.provider_type == ProviderType.SYSTEM.value
and self.provider.quota_type == ProviderQuotaType.FREE.value):
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
return credentials
else:
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@ -6,6 +6,7 @@
"tongyi",
"spark",
"wenxin",
"zhipuai",
"chatglm",
"replicate",
"huggingface_hub",

View File

@ -30,6 +30,12 @@
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-instruct": {
"prompt": "0.0015",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-16k": {
"prompt": "0.003",
"completion": "0.004",

View File

@ -0,0 +1,44 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"free"
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed",
"price_config": {
"chatglm_pro": {
"prompt": "0.01",
"completion": "0.01",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_std": {
"prompt": "0.005",
"completion": "0.005",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_lite": {
"prompt": "0.002",
"completion": "0.002",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_lite_32k": {
"prompt": "0.0004",
"completion": "0.0004",
"unit": "0.001",
"currency": "RMB"
},
"text_embedding": {
"completion": "0",
"unit": "0.001",
"currency": "RMB"
}
}
}

View File

@ -1,6 +1,7 @@
import math
from typing import Optional
from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType
class OrchestratorRuleParser:
@ -63,7 +65,7 @@ class OrchestratorRuleParser:
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_instant=agent_model_instance,
model_instance=agent_model_instance,
conversation_message_task=conversation_message_task
)
@ -123,23 +125,45 @@ class OrchestratorRuleParser:
return chain
def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
-> Optional[SensitiveWordAvoidanceChain]:
"""
Convert app sensitive word avoidance config to chain
:param model_instance: model instance
:param callbacks: callbacks for the chain
:param kwargs:
:return:
"""
if not self.app_model_config.sensitive_word_avoidance_dict:
return None
sensitive_word_avoidance_rule = None
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
if self.app_model_config.sensitive_word_avoidance_dict:
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_config.get("enabled", False):
if sensitive_word_avoidance_config.get('type') == 'moderation':
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.MODERATION,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
)
else:
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_words:
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
extra_params={
'sensitive_words': sensitive_words.split(','),
}
)
if sensitive_word_avoidance_rule:
return SensitiveWordAvoidanceChain(
sensitive_words=sensitive_words.split(","),
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
model_instance=model_instance,
sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
output_key="sensitive_word_avoidance_output",
callbacks=callbacks,
**kwargs

View File

@ -0,0 +1,64 @@
"""Wrapper around ZhipuAI embedding models."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI
class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""Wrapper around ZhipuAI embedding models.
1024 dimensions.
"""
client: Any #: :meta private:
model: str
"""Model name to use."""
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
api_key: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["api_key"] = get_from_dict_or_env(
values, "api_key", "ZHIPUAI_API_KEY"
)
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to ZhipuAI's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
response = self.client.invoke(model=self.model, prompt=text)
data = response["data"]
embeddings.append(data.get('embedding'))
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
"""Call out to ZhipuAI's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]

View File

@ -14,6 +14,9 @@ class EnhanceOpenAI(OpenAI):
max_retries: int = 1
"""Maximum number of retries to make when generating."""
def __new__(cls, **data: Any): # type: ignore
return super(EnhanceOpenAI, cls).__new__(cls)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""

View File

@ -0,0 +1,315 @@
"""Wrapper around ZhipuAI APIs."""
from __future__ import annotations
import json
import logging
import posixpath
from typing import (
Any,
Dict,
List,
Optional, Iterator, Sequence,
)
import zhipuai
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env
from zhipuai.model_api.api import InvokeType
from zhipuai.utils import jwt_token
from zhipuai.utils.http_client import post, stream
from zhipuai.utils.sse_client import SSEClient
logger = logging.getLogger(__name__)
class ZhipuModelAPI(BaseModel):
base_url: str
api_key: str
api_timeout_seconds = 60
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SYNC)
response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
if not response['success']:
raise ValueError(
f"Error Code: {response['code']}, Message: {response['msg']} "
)
return response
def sse_invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SSE)
data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
return SSEClient(data)
def _build_api_url(self, kwargs, *path):
if kwargs:
if "model" not in kwargs:
raise Exception("model param missed")
model = kwargs.pop("model")
else:
model = "-"
return posixpath.join(self.base_url, model, *path)
def _generate_token(self):
if not self.api_key:
raise Exception(
"api_key not provided, you could provide it."
)
try:
return jwt_token.generate_token(self.api_key)
except Exception:
raise ValueError(
f"Your api_key is invalid, please check it."
)
class ZhipuAIChatLLM(BaseChatModel):
"""Wrapper around ZhipuAI large language models.
To use, you should pass the api_key as a named parameter to the constructor.
Example:
.. code-block:: python
from core.third_party.langchain.llms.zhipuai import ZhipuAI
model = ZhipuAI(model="<model_name>", api_key="my-api-key")
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY"}
@property
def lc_serializable(self) -> bool:
return True
client: Any = None #: :meta private:
model: str = "chatglm_lite"
"""Model name to use."""
temperature: float = 0.95
"""A non-negative float that tunes the degree of randomness in generation."""
top_p: float = 0.7
"""Total probability mass of tokens to consider at each step."""
streaming: bool = False
"""Whether to stream the response or return it all at once."""
api_key: Optional[str] = None
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["api_key"] = get_from_dict_or_env(
values, "api_key", "ZHIPUAI_API_KEY"
)
if 'test' in values['base_url']:
values['model'] = 'chatglm_130b_test'
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "zhipuai"
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
dict_messages = []
for m in messages:
message = self._convert_message_to_dict(m)
if dict_messages:
previous_message = dict_messages[-1]
if previous_message['role'] == message['role']:
dict_messages[-1]['content'] += f"\n{message['content']}"
else:
dict_messages.append(message)
else:
dict_messages.append(message)
return dict_messages
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
llm_output: Optional[Dict] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if chunk.generation_info is not None \
and 'token_usage' in chunk.generation_info:
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
continue
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation], llm_output=llm_output)
else:
message_dicts = self._create_message_dicts(messages)
request = self._default_params
request["prompt"] = message_dicts
request.update(kwargs)
response = self.client.invoke(**request)
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages)
request = self._default_params
request["prompt"] = message_dicts
request.update(kwargs)
for event in self.client.sse_invoke(incremental=True, **request).events():
if event.event == "add":
yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
if run_manager:
run_manager.on_llm_new_token(event.data)
elif event.event == "error" or event.event == "interrupted":
raise ValueError(
f"{event.data}"
)
elif event.event == "finish":
meta = json.loads(event.meta)
token_usage = meta['usage']
if token_usage is not None:
if 'prompt_tokens' not in token_usage:
token_usage['prompt_tokens'] = 0
if 'completion_tokens' not in token_usage:
token_usage['completion_tokens'] = token_usage['total_tokens']
yield ChatGenerationChunk(
message=AIMessageChunk(content=event.data),
generation_info=dict({'token_usage': token_usage})
)
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
data = response["data"]
generations = []
for res in data["choices"]:
message = self._convert_dict_to_message(res)
gen = ChatGeneration(
message=message
)
generations.append(gen)
token_usage = data.get("usage")
if token_usage is not None:
if 'prompt_tokens' not in token_usage:
token_usage['prompt_tokens'] = 0
if 'completion_tokens' not in token_usage:
token_usage['completion_tokens'] = token_usage['total_tokens']
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
# def get_token_ids(self, text: str) -> List[int]:
# """Return the ordered ids of the tokens in a text.
#
# Args:
# text: The string input to tokenize.
#
# Returns:
# A list of ids corresponding to the tokens in the text, in order they occur
# in the text.
# """
# from core.third_party.transformers.Token import ChatGLMTokenizer
#
# tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
# return tokenizer.encode(text)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Args:
messages: The message inputs to tokenize.
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(m.content) for m in messages])
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model}

View File

@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool):
return_resource: str
retriever_from: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
@ -94,7 +93,10 @@ class DatasetRetrieverTool(BaseTool):
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': self.k
'k': self.k,
'filter': {
'group_id': [dataset.id]
}
}
)
else:

View File

@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
self.client.delete_collection(collection_name=self.collection_name)
def delete_group(self):
self._reload_if_needed()
self.client.delete_collection(collection_name=self.collection_name)
@classmethod
def _document_from_scored_point(
cls,

View File

@ -0,0 +1,47 @@
"""add_dataset_collection_binding
Revision ID: 6e2cfb077b04
Revises: 77e83833755c
Create Date: 2023-09-13 22:16:48.027810
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6e2cfb077b04'
down_revision = '77e83833755c'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_collection_bindings',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('provider_name', sa.String(length=40), nullable=False),
sa.Column('model_name', sa.String(length=40), nullable=False),
sa.Column('collection_name', sa.String(length=64), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey')
)
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('collection_binding_id')
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.drop_index('provider_model_name_idx')
op.drop_table('dataset_collection_bindings')
# ### end Alembic commands ###

View File

@ -38,6 +38,8 @@ class Dataset(db.Model):
server_default=db.text('CURRENT_TIMESTAMP(0)'))
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True)
@property
def dataset_keyword_table(self):
@ -445,3 +447,19 @@ class Embedding(db.Model):
def get_embedding(self) -> list[float]:
return pickle.loads(self.embedding)
class DatasetCollectionBinding(db.Model):
__tablename__ = 'dataset_collection_bindings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
db.Index('provider_model_name_idx', 'provider_name', 'model_name')
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -147,7 +147,7 @@ class AppModelConfig(db.Model):
"suggested_questions": self.suggested_questions_list,
"suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
"speech_to_text": self.speech_to_text_dict,
"retriever_resource": self.retriever_resource,
"retriever_resource": self.retriever_resource_dict,
"more_like_this": self.more_like_this_dict,
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
"model": self.model_dict,

View File

@ -11,7 +11,7 @@ flask-cors==3.0.10
gunicorn~=21.2.0
gevent~=22.10.2
langchain==0.0.250
openai~=0.27.8
openai~=0.28.0
psycopg2-binary~=2.9.6
pycryptodome==3.17
python-dotenv==1.0.0
@ -19,7 +19,7 @@ pytest~=7.3.1
pytest-mock~=3.11.1
tiktoken==0.3.3
Authlib==1.2.0
boto3~=1.26.123
boto3==1.28.17
tenacity==8.2.2
cachetools~=5.3.0
weaviate-client~=3.21.0
@ -49,5 +49,6 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.2.1
safetensors==0.3.2
xinference==0.4.2
safetensors==0.3.2
zhipuai==1.0.7

View File

@ -408,7 +408,6 @@ class RegisterService:
to=email,
token=token,
inviter_name=inviter.name if inviter else 'Dify',
workspace_id=tenant.id,
workspace_name=tenant.name,
)

View File

@ -366,6 +366,7 @@ class CompletionService:
generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
if not streaming:
try:
message_result = {}
for message in pubsub.listen():
if message["type"] == "message":
result = message["data"].decode('utf-8')
@ -373,7 +374,10 @@ class CompletionService:
if result.get('error'):
cls.handle_error(result)
if result['event'] == 'message' and 'data' in result:
return cls.get_message_response_data(result.get('data'))
message_result['message'] = result.get('data')
if result['event'] == 'message_end' and 'data' in result:
message_result['message_end'] = result.get('data')
return cls.get_blocking_message_response_data(message_result)
except ValueError as e:
if e.args[0] != "I/O operation on closed file.": # ignore this error
raise CompletionStoppedError()
@ -399,7 +403,6 @@ class CompletionService:
if event == "end":
logging.debug("{} finished".format(generate_channel))
break
if event == 'message':
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
elif event == 'chain':
@ -441,6 +444,27 @@ class CompletionService:
return response_data
@classmethod
def get_blocking_message_response_data(cls, data: dict):
message = data.get('message')
response_data = {
'event': 'message',
'task_id': message.get('task_id'),
'id': message.get('message_id'),
'answer': message.get('text'),
'metadata': {},
'created_at': int(time.time())
}
if message.get('mode') == 'chat':
response_data['conversation_id'] = message.get('conversation_id')
if 'message_end' in data:
message_end = data.get('message_end')
if 'retriever_resources' in message_end:
response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources')
return response_data
@classmethod
def get_message_end_data(cls, data: dict):
response_data = {

View File

@ -20,7 +20,8 @@ from events.document_event import document_was_deleted
from extensions.ext_database import db
from libs import helper
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
DatasetCollectionBinding
from models.model import UploadFile
from models.source import DataSourceBinding
from services.errors.account import NoPermissionError
@ -147,6 +148,7 @@ class DatasetService:
action = 'remove'
filtered_data['embedding_model'] = None
filtered_data['embedding_model_provider'] = None
filtered_data['collection_binding_id'] = None
elif data['indexing_technique'] == 'high_quality':
action = 'add'
# get embedding model setting
@ -156,6 +158,11 @@ class DatasetService:
)
filtered_data['embedding_model'] = embedding_model.name
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
f"No Embedding Model available. Please configure a valid provider "
@ -464,7 +471,11 @@ class DocumentService:
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
)
dataset.collection_binding_id = dataset_collection_binding.id
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@ -720,10 +731,16 @@ class DocumentService:
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
embedding_model = None
dataset_collection_binding_id = None
if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
)
dataset_collection_binding_id = dataset_collection_binding.id
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
@ -732,7 +749,8 @@ class DocumentService:
indexing_technique=document_data["indexing_technique"],
created_by=account.id,
embedding_model=embedding_model.name if embedding_model else None,
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
collection_binding_id=dataset_collection_binding_id
)
db.session.add(dataset)
@ -1069,3 +1087,23 @@ class SegmentService:
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
db.session.commit()
class DatasetCollectionBindingService:
@classmethod
def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name). \
order_by(DatasetCollectionBinding.created_at). \
first()
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=provider_name,
model_name=model_name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
)
db.session.add(dataset_collection_binding)
db.session.flush()
return dataset_collection_binding

View File

@ -47,7 +47,10 @@ class HitTestingService:
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 10
'k': 10,
'filter': {
'group_id': [dataset.id]
}
}
)
end = time.perf_counter()

View File

@ -518,7 +518,8 @@ class ProviderService:
def free_quota_submit(self, tenant_id: str, provider_name: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_url = os.environ.get("FREE_QUOTA_APPLY_URL")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
api_url = api_base_url + '/api/v1/providers/apply'
headers = {
'Content-Type': 'application/json',
@ -546,3 +547,42 @@ class ProviderService:
'type': rst['type'],
'result': 'success'
}
def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
api_url = api_base_url + '/api/v1/providers/qualification-verify'
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {api_key}"
}
json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
if token:
json_data['token'] = token
response = requests.post(api_url, headers=headers,
json=json_data)
if not response.ok:
logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
raise ValueError(f"Error: {response.status_code} ")
rst = response.json()
if rst["code"] != 'success':
raise ValueError(
f"error: {rst['message']}"
)
data = rst['data']
if data['qualified'] is True:
return {
'result': 'success',
'provider_name': provider_name,
'flag': True
}
else:
return {
'result': 'success',
'provider_name': provider_name,
'flag': False,
'reason': data['reason']
}

View File

@ -9,16 +9,15 @@ from extensions.ext_mail import mail
@shared_task(queue='mail')
def send_invite_member_mail_task(to: str, token: str, inviter_name: str, workspace_id: str, workspace_name: str):
def send_invite_member_mail_task(to: str, token: str, inviter_name: str, workspace_name: str):
"""
Async Send invite member mail
:param to
:param token
:param inviter_name
:param workspace_id
:param workspace_name
Usage: send_invite_member_mail_task.delay(to, token, inviter_name, workspace_id, workspace_name)
Usage: send_invite_member_mail_task.delay(to, token, inviter_name, workspace_name)
"""
if not mail.is_inited():
return
@ -36,12 +35,7 @@ def send_invite_member_mail_task(to: str, token: str, inviter_name: str, workspa
<p>Click <a href="{url}">here</a> to join.</p>
<p>Thanks,</p>
<p>Dify Team</p>""".format(inviter_name=inviter_name, workspace_name=workspace_name,
url='{}/activate?workspace_id={}&email={}&token={}'.format(
current_app.config.get("CONSOLE_WEB_URL"),
workspace_id,
to,
token)
)
url=f'{current_app.config.get("CONSOLE_WEB_URL")}/activate?token={token}')
)
end_at = time.perf_counter()

View File

@ -31,6 +31,9 @@ TONGYI_DASHSCOPE_API_KEY=
WENXIN_API_KEY=
WENXIN_SECRET_KEY=
# ZhipuAI Credentials
ZHIPUAI_API_KEY=
# ChatGLM Credentials
CHATGLM_API_BASE=

View File

@ -0,0 +1,50 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='zhipuai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'api_key': valid_api_key
}),
is_valid=True,
)
def get_mock_embedding_model():
model_name = 'text_embedding'
valid_api_key = os.environ['ZHIPUAI_API_KEY']
provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
return ZhipuAIEmbedding(
model_provider=provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
embedding_model = get_mock_embedding_model()
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 1024
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_doc_embedding(mock_decrypt):
embedding_model = get_mock_embedding_model()
rst = embedding_model.client.embed_documents(['test', 'test2'])
assert isinstance(rst, list)
assert len(rst[0]) == 1024

View File

@ -42,7 +42,7 @@ def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('text-davinci-003')
openai_model = get_mock_openai_model('gpt-3.5-turbo-instruct')
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
assert rst == 6
@ -61,7 +61,7 @@ def test_chat_get_num_tokens(mock_decrypt):
def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
openai_model = get_mock_openai_model('text-davinci-003')
openai_model = get_mock_openai_model('gpt-3.5-turbo-instruct')
rst = openai_model.run(
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],

View File

@ -0,0 +1,79 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='zhipuai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'api_key': valid_api_key
}),
is_valid=True,
)
def get_mock_model(model_name: str, streaming: bool = False):
model_kwargs = ModelKwargs(
temperature=0.01,
)
valid_api_key = os.environ['ZHIPUAI_API_KEY']
model_provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
return ZhipuAIModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs,
streaming=streaming
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('chatglm_lite')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('chatglm_lite')
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
)
assert len(rst.content) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('chatglm_lite', streaming=True)
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages
)
assert len(rst.content) > 0

View File

@ -2,7 +2,7 @@ import json
import os
from unittest.mock import patch
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_MODEL
from core.model_providers.providers.openai_provider import OpenAIProvider
from models.provider import Provider, ProviderType
@ -23,7 +23,7 @@ def get_mock_openai_moderation_model():
openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
return OpenAIModeration(
model_provider=openai_provider,
name=DEFAULT_AUDIO_MODEL
name=DEFAULT_MODEL
)
@ -36,5 +36,4 @@ def test_run(mock_decrypt):
model = get_mock_openai_moderation_model()
rst = model.run('hello')
assert isinstance(rst, dict)
assert 'id' in rst
assert rst is True

View File

@ -39,7 +39,7 @@ def test_is_provider_credentials_valid_or_raise_invalid():
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
del credential['api_key']
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):

View File

@ -0,0 +1,88 @@
import pytest
from unittest.mock import patch
import json
from langchain.schema import ChatResult, ChatGeneration, AIMessage
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'zhipuai'
MODEL_PROVIDER_CLASS = ZhipuAIProvider
VALIDATE_CREDENTIAL = {
'api_key': 'valid_key',
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.zhipuai_llm.ZhipuAIChatLLM._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
assert all(char == '*' for char in middle_token)

View File

@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.3.22
image: langgenius/dify-api:0.3.23
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -124,7 +124,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.3.22
image: langgenius/dify-api:0.3.23
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@ -176,7 +176,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.3.22
image: langgenius/dify-web:0.3.23
restart: always
environment:
EDITION: SELF_HOSTED

View File

@ -1,7 +1,6 @@
'use client'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
import { useEffect, useRef } from 'react'
import useSWRInfinite from 'swr/infinite'
import { useTranslation } from 'react-i18next'
import AppCard from './AppCard'
@ -10,8 +9,7 @@ import type { AppListResponse } from '@/models/app'
import { fetchAppList } from '@/service/apps'
import { useAppContext } from '@/context/app-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import Confirm from '@/app/components/base/confirm/common'
import { CheckModal } from '@/hooks/use-pay'
const getKey = (pageIndex: number, previousPageData: AppListResponse) => {
if (!pageIndex || previousPageData.has_more)
@ -24,16 +22,6 @@ const Apps = () => {
const { isCurrentWorkspaceManager } = useAppContext()
const { data, isLoading, setSize, mutate } = useSWRInfinite(getKey, fetchAppList, { revalidateFirstPage: false })
const anchorRef = useRef<HTMLDivElement>(null)
const searchParams = useSearchParams()
const router = useRouter()
const payProviderName = searchParams.get('provider_name')
const payStatus = searchParams.get('payment_result')
const [showPayStatusModal, setShowPayStatusModal] = useState(false)
const handleCancelShowPayStatusModal = useCallback(() => {
setShowPayStatusModal(false)
router.replace('/', { forceOptimisticNavigation: false })
}, [router])
useEffect(() => {
document.title = `${t('app.title')} - Dify`
@ -41,9 +29,7 @@ const Apps = () => {
localStorage.removeItem(NEED_REFRESH_APP_LIST_KEY)
mutate()
}
if (payProviderName === ProviderEnum.anthropic && (payStatus === 'succeeded' || payStatus === 'cancelled'))
setShowPayStatusModal(true)
}, [mutate, payProviderName, payStatus, t])
}, [mutate, t])
useEffect(() => {
let observer: IntersectionObserver | undefined
@ -64,27 +50,7 @@ const Apps = () => {
{data?.map(({ data: apps }) => apps.map(app => (
<AppCard key={app.id} app={app} onRefresh={mutate} />
)))}
{
showPayStatusModal && (
<Confirm
isShow
onCancel={handleCancelShowPayStatusModal}
onConfirm={handleCancelShowPayStatusModal}
type={
payStatus === 'succeeded'
? 'success'
: 'danger'
}
title={
payStatus === 'succeeded'
? t('common.actionMsg.paySucceeded')
: t('common.actionMsg.payCancelled')
}
showOperateCancel={false}
confirmText={(payStatus === 'cancelled' && t('common.operation.ok')) || ''}
/>
)
}
<CheckModal />
</nav>
<div ref={anchorRef} className='h-0'> </div>
</>

View File

@ -75,7 +75,7 @@ const Popup: FC<PopupProps> = ({
<Link
href={`/datasets/${source.dataset_id}/documents/${source.document_id}`}
className='hidden items-center h-[18px] text-xs text-primary-600 group-hover:flex'>
Link to dataset
{t('common.chat.citation.linkToDataset')}
<ArrowUpRight className='ml-1 w-3 h-3' />
</Link>
</div>

View File

@ -0,0 +1,100 @@
import React, { useEffect, useRef, useState } from 'react'
import mermaid from 'mermaid'
import { t } from 'i18next'
import CryptoJS from 'crypto-js'
let mermaidAPI: any
mermaidAPI = null
if (typeof window !== 'undefined') {
mermaid.initialize({
startOnLoad: true,
theme: 'default',
flowchart: {
htmlLabels: true,
useMaxWidth: true,
},
})
mermaidAPI = mermaid.mermaidAPI
}
const style = {
minWidth: '480px',
height: 'auto',
overflow: 'auto',
}
// eslint-disable-next-line react/display-name
const Flowchart = React.forwardRef((props: {
PrimitiveCode: string
}, ref) => {
const [svgCode, setSvgCode] = useState(null)
const chartId = useRef(`flowchart_${CryptoJS.MD5(props.PrimitiveCode).toString()}`)
const [isRender, setIsRender] = useState(true)
const renderFlowchart = async (PrimitiveCode: string) => {
try {
const cachedSvg: any = localStorage.getItem(chartId.current)
if (cachedSvg) {
setSvgCode(cachedSvg)
return
}
if (typeof window !== 'undefined' && mermaidAPI) {
const svgGraph = await mermaidAPI.render(chartId.current, PrimitiveCode)
// eslint-disable-next-line @typescript-eslint/no-use-before-define
const base64Svg: any = await svgToBase64(svgGraph.svg)
localStorage.setItem(chartId.current, base64Svg)
setSvgCode(base64Svg)
}
}
catch (error) {
localStorage.clear()
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
console.error(error.toString())
}
}
const svgToBase64 = (svgGraph: string) => {
const svgBytes = new TextEncoder().encode(svgGraph)
const blob = new Blob([svgBytes], { type: 'image/svg+xml;charset=utf-8' })
return new Promise((resolve, reject) => {
const reader = new FileReader()
reader.onloadend = () => resolve(reader.result)
reader.onerror = reject
reader.readAsDataURL(blob)
})
}
const handleReRender = () => {
setIsRender(false)
setSvgCode(null)
localStorage.removeItem(chartId.current)
setTimeout(() => {
setIsRender(true)
renderFlowchart(props.PrimitiveCode)
}, 100)
}
useEffect(() => {
setIsRender(false)
setTimeout(() => {
setIsRender(true)
renderFlowchart(props.PrimitiveCode)
}, 100)
}, [props.PrimitiveCode])
return (
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
<div ref={ref}>
{
isRender && <div id={chartId.current} className="mermaid" style={style}>{svgCode && (<img src={svgCode} style={{ width: '100%', height: 'auto' }} alt="Mermaid chart" />)}</div>
}
<button onClick={handleReRender}>{t('appApi.merMaind.rerender')}</button>
</div>
)
})
export default Flowchart

View File

@ -0,0 +1,23 @@
import React from 'react'
import s from './style.module.css'
type ISVGBtnProps = {
isSVG: boolean
setIsSVG: React.Dispatch<React.SetStateAction<boolean>>
}
const SVGBtn = ({
isSVG,
setIsSVG,
}: ISVGBtnProps) => {
return (
<div
className={'box-border p-0.5 flex items-center justify-center rounded-md bg-white cursor-pointer'}
onClick={() => { setIsSVG(prevIsSVG => !prevIsSVG) }}
>
<div className={`w-6 h-6 rounded-md hover:bg-gray-50 ${s.svgIcon} ${isSVG ? s.svgIconed : ''}`}></div>
</div>
)
}
export default SVGBtn

View File

@ -0,0 +1,11 @@
.svgIcon {
background-image: url(~@/app/components/develop/secret-key/assets/svg.svg);
background-position: center;
background-repeat: no-repeat;
}
.svgIconed {
background-image: url(~@/app/components/develop/secret-key/assets/svged.svg);
background-position: center;
background-repeat: no-repeat;
}

View File

@ -213,7 +213,7 @@ const ConfigModel: FC<IConfigModelProps> = ({
const handleParamChange = (key: string, value: number) => {
const currParamsRule = getAllParams()[provider]?.[modelId]
let notOutRangeValue = parseFloat(value.toFixed(2))
let notOutRangeValue = parseFloat((value || 0).toFixed(2))
notOutRangeValue = Math.max(currParamsRule[key].min, notOutRangeValue)
notOutRangeValue = Math.min(currParamsRule[key].max, notOutRangeValue)

View File

@ -1,9 +1,20 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import React, { useEffect } from 'react'
import Tooltip from '@/app/components/base/tooltip'
import Slider from '@/app/components/base/slider'
export const getFitPrecisionValue = (num: number, precision: number | null) => {
if (!precision || !(`${num}`).includes('.'))
return num
const currNumPrecision = (`${num}`).split('.')[1].length
if (currNumPrecision > precision)
return parseFloat(num.toFixed(precision))
return num
}
export type IParamIteProps = {
id: string
name: string
@ -12,16 +23,32 @@ export type IParamIteProps = {
step?: number
min?: number
max: number
precision: number | null
onChange: (key: string, value: number) => void
}
const ParamIte: FC<IParamIteProps> = ({ id, name, tip, step = 0.1, min = 0, max, value, onChange }) => {
const TIMES_TEMPLATE = '1000000000000'
const ParamItem: FC<IParamIteProps> = ({ id, name, tip, step = 0.1, min = 0, max, precision, value, onChange }) => {
const getToIntTimes = (num: number) => {
if (precision)
return parseInt(TIMES_TEMPLATE.slice(0, precision + 1), 10)
if (num < 5)
return 10
return 1
}
const times = getToIntTimes(max)
useEffect(() => {
if (precision)
onChange(id, getFitPrecisionValue(value, precision))
}, [value, precision])
return (
<div className="flex items-center justify-between">
<div className="flex items-center">
<span className="mr-[6px] text-gray-500 text-[13px] font-medium">{name}</span>
{/* Give tooltip different tip to avoiding hide bug */}
<Tooltip htmlContent={<div className="w-[200px]">{tip}</div>} position='top' selector={`param-name-tooltip-${id}`}>
<Tooltip htmlContent={<div className="w-[200px] whitespace-pre-wrap">{tip}</div>} position='top' selector={`param-name-tooltip-${id}`}>
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8.66667 10.6667H8V8H7.33333M8 5.33333H8.00667M14 8C14 8.78793 13.8448 9.56815 13.5433 10.2961C13.2417 11.0241 12.7998 11.6855 12.2426 12.2426C11.6855 12.7998 11.0241 13.2417 10.2961 13.5433C9.56815 13.8448 8.78793 14 8 14C7.21207 14 6.43185 13.8448 5.7039 13.5433C4.97595 13.2417 4.31451 12.7998 3.75736 12.2426C3.20021 11.6855 2.75825 11.0241 2.45672 10.2961C2.15519 9.56815 2 8.78793 2 8C2 6.4087 2.63214 4.88258 3.75736 3.75736C4.88258 2.63214 6.4087 2 8 2C9.5913 2 11.1174 2.63214 12.2426 3.75736C13.3679 4.88258 14 6.4087 14 8Z" stroke="#9CA3AF" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
</svg>
@ -29,17 +56,21 @@ const ParamIte: FC<IParamIteProps> = ({ id, name, tip, step = 0.1, min = 0, max,
</div>
<div className="flex items-center">
<div className="mr-4 w-[120px]">
<Slider value={max < 5 ? value * 10 : value} min={min < 0 ? min * 10 : min} max={max < 5 ? max * 10 : max} onChange={value => onChange(id, value / (max < 5 ? 10 : 1))} />
<Slider value={value * times} min={min * times} max={max * times} onChange={(value) => {
onChange(id, value / times)
}} />
</div>
<input type="number" min={min} max={max} step={step} className="block w-[64px] h-9 leading-9 rounded-lg border-0 pl-1 pl py-1.5 bg-gray-50 text-gray-900 placeholder:text-gray-400 focus:ring-1 focus:ring-inset focus:ring-primary-600" value={value} onChange={(e) => {
const value = parseFloat(e.target.value)
if (value < min || value > max)
return
let value = getFitPrecisionValue(isNaN(parseFloat(e.target.value)) ? min : parseFloat(e.target.value), precision)
if (value < min)
value = min
if (value > max)
value = max
onChange(id, value)
}} />
</div>
</div>
)
}
export default React.memo(ParamIte)
export default React.memo(ParamItem)

View File

@ -1,5 +1,6 @@
'use client'
import cn from 'classnames'
import { useState } from 'react'
type AvatarProps = {
name: string
@ -17,14 +18,20 @@ const Avatar = ({
}: AvatarProps) => {
const avatarClassName = 'shrink-0 flex items-center rounded-full bg-primary-600'
const style = { width: `${size}px`, height: `${size}px`, fontSize: `${size}px`, lineHeight: `${size}px` }
const [imgError, setImgError] = useState(false)
if (avatar) {
const handleError = () => {
setImgError(true)
}
if (avatar && !imgError) {
return (
<img
className={cn(avatarClassName, className)}
style={style}
alt={name}
src={avatar}
onError={handleError}
/>
)
}

View File

@ -8,7 +8,7 @@ import { AlertCircle } from '@/app/components/base/icons/src/vender/solid/alerts
import { CheckCircle } from '@/app/components/base/icons/src/vender/solid/general'
import Button from '@/app/components/base/button'
type ConfirmCommonProps = {
export type ConfirmCommonProps = {
type?: string
isShow: boolean
onCancel: () => void

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 13 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 13 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 10 KiB

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,16 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './Zhipuai.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'Zhipuai'
export default Icon

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,16 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './ZhipuaiText.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'ZhipuaiText'
export default Icon

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,16 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './ZhipuaiTextCn.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'ZhipuaiTextCn'
export default Icon

View File

@ -29,3 +29,6 @@ export { default as ReplicateText } from './ReplicateText'
export { default as Replicate } from './Replicate'
export { default as XorbitsInferenceText } from './XorbitsInferenceText'
export { default as XorbitsInference } from './XorbitsInference'
export { default as ZhipuaiTextCn } from './ZhipuaiTextCn'
export { default as ZhipuaiText } from './ZhipuaiText'
export { default as Zhipuai } from './Zhipuai'

View File

@ -8,18 +8,28 @@ import SyntaxHighlighter from 'react-syntax-highlighter'
import { atelierHeathLight } from 'react-syntax-highlighter/dist/esm/styles/hljs'
import type { RefObject } from 'react'
import { useEffect, useRef, useState } from 'react'
import cn from 'classnames'
import CopyBtn from '@/app/components/app/chat/copy-btn'
import SVGBtn from '@/app/components/app/chat/svg'
import Flowchart from '@/app/components/app/chat/mermaid'
import s from '@/app/components/app/chat/style.module.css'
// Available language https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/AVAILABLE_LANGUAGES_HLJS.MD
const capitalizationLanguageNameMap: Record<string, string> = {
sql: 'SQL',
javascript: 'JavaScript',
java: 'Java',
typescript: 'TypeScript',
vbscript: 'VBScript',
css: 'CSS',
html: 'HTML',
xml: 'XML',
php: 'PHP',
python: 'Python',
yaml: 'Yaml',
mermaid: 'Mermaid',
markdown: 'MarkDown',
makefile: 'MakeFile',
}
const getCorrectCapitalizationLanguageName = (language: string) => {
if (!language)
@ -73,6 +83,7 @@ const useLazyLoad = (ref: RefObject<Element>): boolean => {
export function Markdown(props: { content: string }) {
const [isCopied, setIsCopied] = useState(false)
const [isSVG, setIsSVG] = useState(false)
return (
<div className="markdown-body">
<ReactMarkdown
@ -95,24 +106,35 @@ export function Markdown(props: { content: string }) {
}}
>
<div className='text-[13px] text-gray-500 font-normal'>{languageShowName}</div>
<CopyBtn
value={String(children).replace(/\n$/, '')}
isPlain
/>
<div style={{ display: 'flex' }}>
{language === 'mermaid'
&& <SVGBtn
isSVG={isSVG}
setIsSVG={setIsSVG}
/>
}
<CopyBtn
className={cn(s.copyBtn, 'mr-1')}
value={String(children).replace(/\n$/, '')}
isPlain
/>
</div>
</div>
<SyntaxHighlighter
{...props}
style={atelierHeathLight}
customStyle={{
paddingLeft: 12,
backgroundColor: '#fff',
}}
language={match[1]}
showLineNumbers
PreTag="div"
>
{String(children).replace(/\n$/, '')}
</SyntaxHighlighter>
{ (language === 'mermaid' && isSVG)
? (<Flowchart PrimitiveCode={String(children).replace(/\n$/, '')} />)
: (<SyntaxHighlighter
{...props}
style={atelierHeathLight}
customStyle={{
paddingLeft: 12,
backgroundColor: '#fff',
}}
language={match[1]}
showLineNumbers
PreTag="div"
>
{String(children).replace(/\n$/, '')}
</SyntaxHighlighter>)}
</div>
)
: (

View File

@ -487,8 +487,10 @@ const StepTwo = ({
<input
type="number"
className={s.input}
placeholder={t('datasetCreation.stepTwo.separatorPlaceholder') || ''} value={max}
onChange={e => setMax(Number(e.target.value))}
placeholder={t('datasetCreation.stepTwo.separatorPlaceholder') || ''}
value={max}
min={1}
onChange={e => setMax(parseInt(e.target.value.replace(/^0+/, ''), 10))}
/>
</div>
</div>
@ -497,7 +499,7 @@ const StepTwo = ({
<div className={s.label}>{t('datasetCreation.stepTwo.rules')}</div>
{rules.map(rule => (
<div key={rule.id} className={s.ruleItem}>
<input id={rule.id} type="checkbox" defaultChecked={rule.enabled} onChange={() => ruleChangeHandle(rule.id)} className="w-4 h-4 rounded border-gray-300 text-blue-700 focus:ring-blue-700" />
<input id={rule.id} type="checkbox" checked={rule.enabled} onChange={() => ruleChangeHandle(rule.id)} className="w-4 h-4 rounded border-gray-300 text-blue-700 focus:ring-blue-700" />
<label htmlFor={rule.id} className="ml-2 text-sm font-normal cursor-pointer text-gray-800">{getRuleName(rule.id)}</label>
</div>
))}

View File

@ -0,0 +1 @@
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1694177685288" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="4415" xmlns:xlink="http://www.w3.org/1999/xlink" width="16" height="16"><path d="M192 384h640a42.666667 42.666667 0 0 1 42.666667 42.666667v362.666666a42.666667 42.666667 0 0 1-42.666667 42.666667H192v106.666667a21.333333 21.333333 0 0 0 21.333333 21.333333h725.333334a21.333333 21.333333 0 0 0 21.333333-21.333333V308.821333L949.909333 298.666667h-126.528A98.048 98.048 0 0 1 725.333333 200.618667V72.661333L716.714667 64H213.333333a21.333333 21.333333 0 0 0-21.333333 21.333333v298.666667zM128 832H42.666667a42.666667 42.666667 0 0 1-42.666667-42.666667V426.666667a42.666667 42.666667 0 0 1 42.666667-42.666667h85.333333V85.333333a85.333333 85.333333 0 0 1 85.333333-85.333333h530.026667L1024 282.453333V938.666667a85.333333 85.333333 0 0 1-85.333333 85.333333H213.333333a85.333333 85.333333 0 0 1-85.333333-85.333333v-106.666667z m61.376-364.885333c-27.434667 0-49.898667 6.528-67.712 19.968-19.221333 13.824-28.501333 33.024-28.501333 57.216s9.621333 42.624 29.226666 55.296c7.466667 4.608 27.093333 12.288 58.432 23.04 28.138667 9.216 44.522667 15.36 49.514667 18.048 15.68 8.448 23.872 19.968 23.872 34.56 0 11.52-5.696 20.352-16.384 27.264-10.688 6.528-25.664 9.984-44.181333 9.984-21.013333 0-36.352-4.224-46.314667-11.904-11.050667-8.832-17.813333-23.808-20.672-44.544H85.333333c1.792 34.944 13.546667 60.288 34.922667 76.416 17.450667 13.056 42.026667 19.584 73.386667 19.584 32.426667 0 57.706667-7.296 75.52-21.12 17.813333-14.208 26.730667-33.792 26.730666-58.368 0-25.344-11.050667-44.928-33.130666-59.136-9.984-6.144-32.064-15.36-66.624-26.88-23.509333-8.064-38.122667-13.824-43.477334-16.896-12.096-6.912-17.813333-16.512-17.813333-28.032 0-13.056 4.992-22.656 15.68-28.416 8.554667-4.992 20.672-7.296 36.693333-7.296 18.538667 0 32.789333 3.456 42.048 11.136 9.258667 7.296 16.021333 19.584 19.584 36.48h41.344c-2.496-29.952-12.821333-52.224-30.656-66.432-16.725333-13.44-40.256-19.968-70.186666-19.968z m118.976 5.376L398.848 746.666667h50.24l90.496-274.176h-45.226667l-69.845333 223.488h-1.066667l-69.845333-223.488h-45.226667z m368.405333-5.376c-37.76 0-67.690667 13.824-89.792 42.24-21.013333 26.496-31.36 60.288-31.36 101.376 0 40.704 10.346667 74.112 31.36 99.84 22.442667 27.648 53.802667 41.472 94.421334 41.472 22.805333 0 43.093333-3.072 61.632-9.216A143.829333 143.829333 0 0 0 789.333333 716.714667V600.746667h-109.013333v38.4h67.328v56.448c-8.533333 5.376-17.450667 9.6-27.434667 12.672a123.285333 123.285333 0 0 1-34.197333 4.608c-30.997333 0-53.802667-9.216-68.416-27.648-13.525333-17.28-20.309333-42.24-20.309333-74.496 0-33.792 7.488-59.52 22.826666-77.952 13.866667-17.664 32.768-26.112 56.64-26.112 19.221333 0 34.901333 4.224 46.656 13.056 11.413333 8.832 19.242667 21.888 22.826667 39.552h42.026667c-4.629333-30.72-16.042667-53.376-34.197334-68.736-18.88-15.744-44.544-23.424-77.312-23.424z" fill="#8a8a8a" p-id="4416"></path></svg>

After

Width:  |  Height:  |  Size: 3.0 KiB

View File

@ -0,0 +1 @@
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1694177378730" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="4206" width="16" height="16" xmlns:xlink="http://www.w3.org/1999/xlink"><path d="M192 384h640a42.666667 42.666667 0 0 1 42.666667 42.666667v362.666666a42.666667 42.666667 0 0 1-42.666667 42.666667H192v106.666667a21.333333 21.333333 0 0 0 21.333333 21.333333h725.333334a21.333333 21.333333 0 0 0 21.333333-21.333333V308.821333L949.909333 298.666667h-126.528A98.048 98.048 0 0 1 725.333333 200.618667V72.661333L716.714667 64H213.333333a21.333333 21.333333 0 0 0-21.333333 21.333333v298.666667zM128 832H42.666667a42.666667 42.666667 0 0 1-42.666667-42.666667V426.666667a42.666667 42.666667 0 0 1 42.666667-42.666667h85.333333V85.333333a85.333333 85.333333 0 0 1 85.333333-85.333333h530.026667L1024 282.453333V938.666667a85.333333 85.333333 0 0 1-85.333333 85.333333H213.333333a85.333333 85.333333 0 0 1-85.333333-85.333333v-106.666667z m61.376-364.885333c-27.434667 0-49.898667 6.528-67.712 19.968-19.221333 13.824-28.501333 33.024-28.501333 57.216s9.621333 42.624 29.226666 55.296c7.466667 4.608 27.093333 12.288 58.432 23.04 28.138667 9.216 44.522667 15.36 49.514667 18.048 15.68 8.448 23.872 19.968 23.872 34.56 0 11.52-5.696 20.352-16.384 27.264-10.688 6.528-25.664 9.984-44.181333 9.984-21.013333 0-36.352-4.224-46.314667-11.904-11.050667-8.832-17.813333-23.808-20.672-44.544H85.333333c1.792 34.944 13.546667 60.288 34.922667 76.416 17.450667 13.056 42.026667 19.584 73.386667 19.584 32.426667 0 57.706667-7.296 75.52-21.12 17.813333-14.208 26.730667-33.792 26.730666-58.368 0-25.344-11.050667-44.928-33.130666-59.136-9.984-6.144-32.064-15.36-66.624-26.88-23.509333-8.064-38.122667-13.824-43.477334-16.896-12.096-6.912-17.813333-16.512-17.813333-28.032 0-13.056 4.992-22.656 15.68-28.416 8.554667-4.992 20.672-7.296 36.693333-7.296 18.538667 0 32.789333 3.456 42.048 11.136 9.258667 7.296 16.021333 19.584 19.584 36.48h41.344c-2.496-29.952-12.821333-52.224-30.656-66.432-16.725333-13.44-40.256-19.968-70.186666-19.968z m118.976 5.376L398.848 746.666667h50.24l90.496-274.176h-45.226667l-69.845333 223.488h-1.066667l-69.845333-223.488h-45.226667z m368.405333-5.376c-37.76 0-67.690667 13.824-89.792 42.24-21.013333 26.496-31.36 60.288-31.36 101.376 0 40.704 10.346667 74.112 31.36 99.84 22.442667 27.648 53.802667 41.472 94.421334 41.472 22.805333 0 43.093333-3.072 61.632-9.216A143.829333 143.829333 0 0 0 789.333333 716.714667V600.746667h-109.013333v38.4h67.328v56.448c-8.533333 5.376-17.450667 9.6-27.434667 12.672a123.285333 123.285333 0 0 1-34.197333 4.608c-30.997333 0-53.802667-9.216-68.416-27.648-13.525333-17.28-20.309333-42.24-20.309333-74.496 0-33.792 7.488-59.52 22.826666-77.952 13.866667-17.664 32.768-26.112 56.64-26.112 19.221333 0 34.901333 4.224 46.656 13.056 11.413333 8.832 19.242667 21.888 22.826667 39.552h42.026667c-4.629333-30.72-16.042667-53.376-34.197334-68.736-18.88-15.744-44.544-23.424-77.312-23.424z" fill="#1A8EF7" p-id="4207"></path></svg>

After

Width:  |  Height:  |  Size: 3.0 KiB

View File

@ -48,7 +48,7 @@ const ItemOperation: FC<IItemOperationProps> = ({
<PortalToFollowElemTrigger
onClick={() => setOpen(v => !v)}
>
<div className={cn(className, s.btn, 'h-6 w-6 rounded-md border-none py-1', open && `${s.open} !bg-gray-100 !shadow-none`)}></div>
<div className={cn(className, s.btn, 'h-6 w-6 rounded-md border-none py-1', (isItemHovering || open) && `${s.open} !bg-gray-100 !shadow-none`)}></div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent
className="z-50"

View File

@ -11,6 +11,7 @@ import chatglm from './chatglm'
import xinference from './xinference'
import openllm from './openllm'
import localai from './localai'
import zhipuai from './zhipuai'
export default {
openai,
@ -26,4 +27,5 @@ export default {
xinference,
openllm,
localai,
zhipuai,
}

View File

@ -0,0 +1,55 @@
import { ProviderEnum } from '../declarations'
import type { ProviderConfig } from '../declarations'
import { Zhipuai, ZhipuaiText, ZhipuaiTextCn } from '@/app/components/base/icons/src/public/llm'
const config: ProviderConfig = {
selector: {
name: {
'en': 'ZHIPU AI',
'zh-Hans': '智谱 AI',
},
icon: <Zhipuai className='w-full h-full' />,
},
item: {
key: ProviderEnum.zhipuai,
titleIcon: {
'en': <ZhipuaiText className='-ml-1 h-7' />,
'zh-Hans': <ZhipuaiTextCn className='h-8' />,
},
},
modal: {
key: ProviderEnum.zhipuai,
title: {
'en': 'ZHIPU AI',
'zh-Hans': '智谱 AI',
},
icon: <Zhipuai className='w-6 h-6' />,
link: {
href: 'https://open.bigmodel.cn/usercenter/apikeys',
label: {
'en': 'Get your API key from ZHIPU AI',
'zh-Hans': '从智谱 AI 获取 API Key',
},
},
validateKeys: [
'api_key',
],
fields: [
{
type: 'text',
key: 'api_key',
required: true,
label: {
'en': 'APIKey',
'zh-Hans': 'APIKey',
},
placeholder: {
'en': 'Enter your APIKey here',
'zh-Hans': '在此输入您的 APIKey',
},
},
],
},
}
export default config

View File

@ -42,6 +42,7 @@ export enum ProviderEnum {
'xinference' = 'xinference',
'openllm' = 'openllm',
'localai' = 'localai',
'zhipuai' = 'zhipuai',
}
export type ProviderConfigItem = {

Some files were not shown because too many files have changed in this diff Show More