Mirror of https://github.com/langgenius/dify.git, synced 2026-01-28 07:45:58 +08:00
Compare commits
26 Commits
| SHA1 |
|---|
| 2feb16d957 |
| 3043fbe73b |
| 9f99c3f55b |
| a07a6d8c26 |
| 695841a3cf |
| 3efaa713da |
| 9822f687f7 |
| b9d83c04bc |
| 298ad6782d |
| f4be2b8bcd |
| e83e239faf |
| 62bf7f0fc2 |
| 7dea485d57 |
| 5b9858a8a3 |
| 42a5b3ec17 |
| 2d1cb076c6 |
| 289c93d081 |
| c0fe706597 |
| 9cba1c8bf4 |
| cbf095465c |
| c007dbdc13 |
| ff493d017b |
| 7f6ad9653e |
| 2851a9f04e |
| c536f85b2e |
| b1352ff8b7 |
.github/workflows/build-api-image.yml (vendored): 2 lines changed
@@ -31,7 +31,7 @@ jobs:
         with:
           images: langgenius/dify-api
           tags: |
-            type=raw,value=latest,enable={{is_default_branch}}
+            type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
             type=ref,event=branch
             type=sha,enable=true,priority=100,prefix=,suffix=,format=long
             type=semver,pattern={{major}}.{{minor}}.{{patch}}
.github/workflows/build-web-image.yml (vendored): 2 lines changed
@@ -31,7 +31,7 @@ jobs:
         with:
           images: langgenius/dify-web
           tags: |
-            type=raw,value=latest,enable={{is_default_branch}}
+            type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
             type=ref,event=branch
             type=sha,enable=true,priority=100,prefix=,suffix=,format=long
             type=semver,pattern={{major}}.{{minor}}.{{patch}}
.github/workflows/check_no_chinese_comments.py (vendored): file deleted (37 lines)
@@ -1,37 +0,0 @@
import os
import re
from zhon.hanzi import punctuation

def has_chinese_characters(text):
    for char in text:
        if '\u4e00' <= char <= '\u9fff' or char in punctuation:
            return True
    return False

def check_file_for_chinese_comments(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, start=1):
            if has_chinese_characters(line):
                print(f"Found Chinese characters in {file_path} on line {line_number}:")
                print(line.strip())
                return True
    return False

def main():
    has_chinese = False
    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
                      'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
                      'prompts.py']

    for root, _, files in os.walk("."):
        for file in files:
            if file.endswith(".py") and file not in excluded_files:
                file_path = os.path.join(root, file)
                if check_file_for_chinese_comments(file_path):
                    has_chinese = True

    if has_chinese:
        raise Exception("Found Chinese characters in Python files. Please remove them.")

if __name__ == "__main__":
    main()
.github/workflows/check_no_chinese_comments.yml (vendored): file deleted (31 lines)
@@ -1,31 +0,0 @@
name: Check for Chinese comments

on:
  push:
    branches:
      - 'main'
  pull_request:
    branches:
      - main

jobs:
  check-chinese-comments:
    runs-on: ubuntu-latest

    steps:
      - name: Check out repository
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.9

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install zhon

      - name: Run script to check for Chinese comments
        run: |
          python .github/workflows/check_no_chinese_comments.py
@@ -6,6 +6,9 @@ from werkzeug.exceptions import Unauthorized
 if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
     from gevent import monkey
     monkey.patch_all()
+    if os.environ.get("VECTOR_STORE") == 'milvus':
+        import grpc.experimental.gevent
+        grpc.experimental.gevent.init_gevent()
 
 import logging
 import json
@@ -92,7 +92,7 @@ class Config:
         self.CONSOLE_URL = get_env('CONSOLE_URL')
         self.API_URL = get_env('API_URL')
         self.APP_URL = get_env('APP_URL')
-        self.CURRENT_VERSION = "0.3.26"
+        self.CURRENT_VERSION = "0.3.28"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@ -31,6 +31,7 @@ model_templates = {
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
@ -81,6 +82,7 @@ model_templates = {
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
@ -137,10 +139,11 @@ demo_model_templates = {
|
||||
},
|
||||
opening_statement='',
|
||||
suggested_questions=None,
|
||||
pre_prompt="Please translate the following text into {{target_language}}:\n",
|
||||
pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
@ -169,6 +172,13 @@ demo_model_templates = {
|
||||
'Italian',
|
||||
]
|
||||
}
|
||||
},{
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
])
|
||||
)
|
||||
@ -200,6 +210,7 @@ demo_model_templates = {
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.8,
|
||||
@ -255,10 +266,11 @@ demo_model_templates = {
|
||||
},
|
||||
opening_statement='',
|
||||
suggested_questions=None,
|
||||
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
|
||||
pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
@ -287,6 +299,13 @@ demo_model_templates = {
|
||||
"意大利语",
|
||||
]
|
||||
}
|
||||
},{
|
||||
"paragraph": {
|
||||
"label": "文本内容",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
])
|
||||
)
|
||||
@ -318,6 +337,7 @@ demo_model_templates = {
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.8,
|
||||
|
||||
@@ -9,7 +9,7 @@ api = ExternalApi(bp)
 from . import setup, version, apikey, admin
 
 # Import app controllers
-from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
+from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
 
 # Import auth controllers
 from .auth import login, oauth, data_source_oauth, activate
api/controllers/console/app/advanced_prompt_template.py: new file (25 lines)
@@ -0,0 +1,25 @@
from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService

class AdvancedPromptTemplateList(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):

        parser = reqparse.RequestParser()
        parser.add_argument('app_mode', type=str, required=True, location='args')
        parser.add_argument('model_mode', type=str, required=True, location='args')
        parser.add_argument('has_context', type=str, required=False, default='true', location='args')
        parser.add_argument('model_name', type=str, required=True, location='args')
        args = parser.parse_args()

        return AdvancedPromptTemplateService.get_prompt(args)

api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
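For context, the new console endpoint registered above can be exercised roughly as follows. This is a minimal sketch, not part of the diff: the `/console/api` prefix, the base URL, and the session cookie are assumptions about how the console API is mounted and authenticated in a typical deployment.

```python
import requests

# Hypothetical deployment values; adjust to your own setup.
BASE_URL = "http://localhost:5001/console/api"    # assumed console API prefix
COOKIES = {"session": "<valid-console-session>"}  # assumed login session

# Query an advanced prompt template for a chat app backed by a chat-mode model.
resp = requests.get(
    f"{BASE_URL}/app/prompt-templates",
    params={
        "app_mode": "chat",             # required by the reqparse definition above
        "model_mode": "chat",           # required
        "has_context": "true",          # optional, defaults to 'true'
        "model_name": "gpt-3.5-turbo",  # required
    },
    cookies=COOKIES,
)
print(resp.status_code, resp.json())
```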
@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
|
||||
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||
|
||||
|
||||
class IntroductionGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('prompt_template', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
|
||||
try:
|
||||
answer = LLMGenerator.generate_introduction(
|
||||
account.current_tenant_id,
|
||||
args['prompt_template']
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
|
||||
return {'introduction': answer}
|
||||
|
||||
|
||||
class RuleGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
|
||||
return rules
|
||||
|
||||
|
||||
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
|
||||
api.add_resource(RuleGenerateApi, '/rule-generate')
|
||||
|
||||
@@ -295,8 +295,8 @@ class MessageSuggestedQuestionApi(Resource):
         try:
             questions = MessageService.get_suggested_questions_after_answer(
                 app_model=app_model,
-                user=current_user,
                 message_id=message_id,
+                user=current_user,
                 check_enabled=False
             )
         except MessageNotExistsError:
@@ -329,7 +329,7 @@ class MessageApi(Resource):
         message_id = str(message_id)
 
         # get app info
-        app_model = _get_app(app_id, 'chat')
+        app_model = _get_app(app_id)
 
         message = db.session.query(Message).filter(
             Message.id == message_id,
@@ -54,6 +54,7 @@ class ConversationDetailApi(AppApiResource):
             raise NotFound("Conversation Not Exists.")
         return {"result": "success"}, 204
 
+
 class ConversationRenameApi(AppApiResource):
 
     @marshal_with(simple_conversation_fields)
@@ -10,6 +10,8 @@ from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import AppApiResource
 from libs.helper import TimestampField, uuid_value
 from services.message_service import MessageService
+from extensions.ext_database import db
+from models.model import Account, Message
 
 
 class MessageListApi(AppApiResource):
@@ -96,5 +98,36 @@ class MessageFeedbackApi(AppApiResource):
         return {'result': 'success'}
 
 
+class MessageSuggestedApi(AppApiResource):
+    def get(self, app_model, end_user, message_id):
+        message_id = str(message_id)
+        if app_model.mode != 'chat':
+            raise NotChatAppError()
+
+        try:
+            message = db.session.query(Message).filter(
+                Message.id == message_id,
+                Message.app_id == app_model.id,
+            ).first()
+
+            if end_user is None and message.from_account_id is not None:
+                user = db.session.get(Account, message.from_account_id)
+            elif end_user is None and message.from_end_user_id is not None:
+                user = create_or_update_end_user_for_user_id(app_model, message.from_end_user_id)
+            else:
+                user = end_user
+
+            questions = MessageService.get_suggested_questions_after_answer(
+                app_model=app_model,
+                user=user,
+                message_id=message_id
+            )
+        except services.errors.message.MessageNotExistsError:
+            raise NotFound("Message Not Exists.")
+
+        return {'result': 'success', 'data': questions}
+
+
 api.add_resource(MessageListApi, '/messages')
 api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
+api.add_resource(MessageSuggestedApi, '/messages/<uuid:message_id>/suggested')
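The newly registered `/messages/<uuid:message_id>/suggested` route can be called from the app-facing API roughly as below. A minimal sketch, assuming the service API is mounted under `/v1` and authenticated with an app API key passed as a Bearer token; neither assumption is shown in this diff.

```python
import requests

BASE_URL = "http://localhost:5001/v1"              # assumed service API prefix
API_KEY = "<app-api-key>"                          # assumed Bearer-token auth
MESSAGE_ID = "<uuid-of-an-existing-chat-message>"

resp = requests.get(
    f"{BASE_URL}/messages/{MESSAGE_ID}/suggested",
    headers={"Authorization": f"Bearer {API_KEY}"},
)
# Per the handler above, a successful response looks like:
# {"result": "success", "data": ["question 1", "question 2", ...]}
print(resp.json())
```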
@@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource):
         streaming = args['response_mode'] == 'streaming'
 
         try:
-            response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
+            response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@ -2,14 +2,18 @@ import json
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
@ -65,17 +73,57 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
try:
|
||||
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
|
||||
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
}
|
||||
)
|
||||
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
return agent_decision
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
@ -87,7 +135,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
@ -96,11 +144,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
|
||||
_format_intermediate_steps
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
|
||||
get_buffer_string
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_model_instance: BaseLLM = None
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 40
|
||||
original_max_tokens = self.model_instance.model_kwargs.max_tokens
|
||||
self.model_instance.model_kwargs.max_tokens = 40
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
try:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
callbacks=None
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
function_call = result.function_call
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
self.model_instance.model_kwargs.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
}
|
||||
)
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
|
||||
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
|
||||
tool_inputs = agent_decision.tool_input
|
||||
@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
human_message = None
|
||||
should_summary_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_message = message
|
||||
elif isinstance(message, HumanMessage):
|
||||
human_message = message
|
||||
else:
|
||||
should_summary_messages.append(message)
|
||||
|
||||
if len(should_summary_messages) > 2:
|
||||
ai_message = should_summary_messages[-2]
|
||||
function_message = should_summary_messages[-1]
|
||||
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||
self.moving_summary_index = len(should_summary_messages)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
new_messages = [system_message, human_message]
|
||||
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, human_message)
|
||||
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||
new_messages.append(ai_message)
|
||||
new_messages.append(function_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix="Human",
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
if model_instance.model_provider.provider_name == 'azure_openai':
|
||||
model = model_instance.base_model_name
|
||||
model = model.replace("gpt-35", "gpt-3.5")
|
||||
else:
|
||||
model = model_instance.base_model_name
|
||||
|
||||
tiktoken_ = _import_tiktoken()
|
||||
try:
|
||||
encoding = tiktoken_.encoding_for_model(model)
|
||||
except KeyError:
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken_.get_encoding(model)
|
||||
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for m in messages:
|
||||
message = _convert_message_to_dict(m)
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if key == "function_call":
|
||||
for f_key, f_value in value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if kwargs.get('functions'):
|
||||
for function in kwargs.get('functions'):
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(function.get("name")))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(function.get("description")))
|
||||
parameters = function.get("parameters")
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
|
||||
@ -1,140 +0,0 @@
|
||||
from typing import cast, List
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.chat_models.openai import _convert_message_to_dict
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
human_message = None
|
||||
should_summary_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_message = message
|
||||
elif isinstance(message, HumanMessage):
|
||||
human_message = message
|
||||
else:
|
||||
should_summary_messages.append(message)
|
||||
|
||||
if len(should_summary_messages) > 2:
|
||||
ai_message = should_summary_messages[-2]
|
||||
function_message = should_summary_messages[-1]
|
||||
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||
self.moving_summary_index = len(should_summary_messages)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
new_messages = [system_message, human_message]
|
||||
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, human_message)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||
new_messages.append(ai_message)
|
||||
new_messages.append(function_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
llm = cast(ChatOpenAI, model_instance.client)
|
||||
model, encoding = llm._get_encoding_model()
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for m in messages:
|
||||
message = _convert_message_to_dict(m)
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if key == "function_call":
|
||||
for f_key, f_value in value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if kwargs.get('functions'):
|
||||
for function in kwargs.get('functions'):
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(function.get("name")))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(function.get("description")))
|
||||
parameters = function.get("parameters")
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
@ -1,107 +0,0 @@
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
|
||||
from langchain.agents import BaseMultiActionAgent
|
||||
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
|
||||
_parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
|
||||
|
||||
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseMultiActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 15
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
try:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
|
||||
@classmethod
|
||||
def get_system_message(cls):
|
||||
# get current time
|
||||
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||
"The current date or current time you know is wrong.\n"
|
||||
"Respond directly if appropriate.")
|
||||
@ -4,7 +4,6 @@ from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
@ -12,6 +11,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
@ -49,7 +49,6 @@ Action:
|
||||
|
||||
|
||||
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
model_instance: BaseLLM
|
||||
dataset_tools: Sequence[BaseTool]
|
||||
|
||||
class Config:
|
||||
@ -98,7 +97,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
@ -108,6 +107,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
@ -145,7 +146,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
@ -157,17 +158,28 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
output_parser=_output_parser,
|
||||
dataset_tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -4,16 +4,17 @@ from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
|
||||
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
|
||||
get_buffer_string
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
@ -52,8 +53,7 @@ Action:
|
||||
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM
|
||||
summary_model_instance: BaseLLM = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -95,14 +95,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
@ -118,7 +118,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2 and self.summary_llm:
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_instance:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
@ -130,11 +130,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
if self.moving_summary_buffer and 'chat_history' in kwargs:
|
||||
kwargs["chat_history"].pop()
|
||||
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
@ -144,6 +143,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
|
||||
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix="Human",
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
@ -176,7 +187,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
@ -188,16 +199,27 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
output_parser=_output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
|
||||
|
||||
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
|
||||
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
|
||||
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
|
||||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
|
||||
REACT_ROUTER = 'react_router'
|
||||
REACT = 'react'
|
||||
FUNCTION_CALL = 'function_call'
|
||||
MULTI_FUNCTION_CALL = 'multi_function_call'
|
||||
|
||||
|
||||
class AgentConfiguration(BaseModel):
|
||||
@ -64,30 +62,18 @@ class AgentExecutor:
|
||||
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
@ -95,7 +81,6 @@ class AgentExecutor:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
|
||||
verbose=True
|
||||
@ -104,7 +89,6 @@ class AgentExecutor:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
verbose=True
|
||||
|
||||
api/core/chain/llm_chain.py: new file (36 lines)
@@ -0,0 +1,36 @@
from typing import List, Dict, Any, Optional

from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel

from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM


class LLMChain(LCLLMChain):
    model_instance: BaseLLM
    """The language model instance to use."""
    llm: BaseLanguageModel = FakeLLM(response="")

    def generate(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
        messages = prompts[0].to_messages()
        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            stop=stop
        )

        generations = [
            [Generation(text=result.content)]
        ]

        return LLMResult(generations=generations)
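Elsewhere in this change set the wrapper is constructed as `LLMChain(model_instance=..., prompt=SUMMARY_PROMPT)` and driven through `predict()`. A minimal usage sketch of that pattern follows; `model_instance` is assumed to be a configured `BaseLLM` obtained from the provider layer, which is outside this diff.

```python
from langchain import PromptTemplate

from core.chain.llm_chain import LLMChain

# `model_instance` is assumed to be an already-configured BaseLLM
# (e.g. resolved via the ModelFactory used elsewhere in this change set).
chain = LLMChain(
    model_instance=model_instance,
    prompt=PromptTemplate.from_template("Summarize the following text:\n\n{text}"),
)

# predict() renders the prompt, converts it with to_prompt_messages(),
# and delegates generation to model_instance.run(), as implemented above.
summary = chain.predict(text="LangChain chains can delegate generation to a custom model runner.")
print(summary)
```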
@@ -1,4 +1,3 @@
import json
import logging
from typing import Optional, List, Union

@@ -16,10 +15,8 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
from core.prompt.prompt_template import PromptTemplateParser
from models.model import App, AppModelConfig, Account, Conversation, EndUser


class Completion:
@@ -30,7 +27,7 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
query = PromptBuilder.process_template(query)
query = PromptTemplateParser.remove_template_variables(query)

memory = None
if conversation:
@@ -160,14 +157,28 @@ class Completion:
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
else:
prompt_messages = model_instance.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)

model_config = app_model_config.model_dict
completion_params = model_config.get("completion_params", {})
stop_words = completion_params.get("stop", [])

cls.recale_llm_max_tokens(
model_instance=model_instance,
@@ -176,7 +187,7 @@ class Completion:

response = model_instance.run(
messages=prompt_messages,
stop=stop_words,
stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
@@ -266,52 +277,3 @@ class Completion:
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)

@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):

final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model_config=app_model_config.model_dict,
streaming=streaming
)

# get llm prompt
old_prompt_messages, _ = final_model_instance.get_prompt(
mode='completion',
pre_prompt=pre_prompt,
inputs=message.inputs,
query=message.query,
context=None,
memory=None
)

original_completion = message.answer.strip()

prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)

prompt_messages = [PromptMessage(content=prompt)]

conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
inputs=message.inputs,
query=message.query,
is_override=True if message.override_model_configs else False,
streaming=streaming,
model_instance=final_model_instance
)

cls.recale_llm_max_tokens(
model_instance=final_model_instance,
prompt_messages=prompt_messages
)

final_model_instance.run(
messages=prompt_messages,
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
)
@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -74,10 +74,10 @@ class ConversationMessageTask:
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
introduction = prompt_template.format(**prompt_inputs)
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass

@@ -150,12 +150,12 @@ class ConversationMessageTask:
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens

message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)

message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price

@@ -163,7 +163,7 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(
self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
@@ -226,15 +226,15 @@ class ConversationMessageTask:

def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)

loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens

loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
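These hunks migrate from the Jinja-based templates to Dify's PromptTemplateParser, which takes the whole template in its constructor, exposes the referenced placeholders via variable_keys, and formats from a plain dict rather than keyword arguments. A small sketch of that assumed behaviour (template text and inputs are illustrative):

from core.prompt.prompt_template import PromptTemplateParser

parser = PromptTemplateParser(template="Hello {{name}}, welcome to {{product}}!")
inputs = {"name": "Ada"}
# only keys the template actually references, and that the caller supplied, are passed in
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
print(parser.format(prompt_inputs))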
@@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser

from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT


class LLMGenerator:
@@ -44,78 +43,19 @@ class LLMGenerator:

return answer.strip()

@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200

model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)

prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
max_context_token_length = max_context_token_length if max_context_token_length else 1500
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1

context = ''
for message in messages:
if not message.answer:
continue

if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query

if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer

message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text

if not context:
return '[message too long, no summary]'

prompt = prompt.format(context=context)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()

@classmethod
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)

model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)

prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()

@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()

prompt = JinjaPromptTemplate(
template="{{histories}}\n{{format_instructions}}\nquestions:\n",
input_variables=["histories"],
partial_variables={"format_instructions": format_instructions}
prompt_template = PromptTemplateParser(
template="{{histories}}\n{{format_instructions}}\nquestions:\n"
)

_input = prompt.format_prompt(histories=histories)
prompt = prompt_template.format({
"histories": histories,
"format_instructions": format_instructions
})

try:
model_instance = ModelFactory.get_text_generation_model(
@@ -128,10 +68,10 @@ class LLMGenerator:
except ProviderTokenNotInitError:
return []

prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]

try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
questions = output_parser.parse(output.content)
except LLMError:
questions = []
@@ -145,19 +85,21 @@ class LLMGenerator:
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
output_parser = RuleConfigGeneratorOutputParser()

prompt = OutLinePromptTemplate(
template=output_parser.get_format_instructions(),
input_variables=["audiences", "hoping_to_solve"],
partial_variables={
"variable": '{variable}',
"lanA": '{lanA}',
"lanB": '{lanB}',
"topic": '{topic}'
},
validate_template=False
prompt_template = PromptTemplateParser(
template=output_parser.get_format_instructions()
)

_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
prompt = prompt_template.format(
inputs={
"audiences": audiences,
"hoping_to_solve": hoping_to_solve,
"variable": "{{variable}}",
"lanA": "{{lanA}}",
"lanB": "{{lanB}}",
"topic": "{{topic}}"
},
remove_template_variables=False
)

model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
@@ -167,10 +109,10 @@ class LLMGenerator:
)
)

prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]

try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
rule_config = output_parser.parse(output.content)
except LLMError as e:
raise e
@@ -1,4 +1,5 @@
import logging
import random

import openai

@@ -16,19 +17,20 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]

max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
if len(text_chunks) == 0:
return True

for text_chunk in chunks:
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
text_chunk = random.choice(text_chunks)

for result in moderation_result.results:
if result['flagged'] is True:
return False
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')

for result in moderation_result.results:
if result['flagged'] is True:
return False

return True
@@ -167,8 +167,6 @@ class Milvus(VectorStore):
self._init()

@property

def embeddings(self) -> Embeddings:
return self.embedding_func
@@ -11,6 +11,7 @@ from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
@@ -79,6 +80,8 @@ class IndexingRunner:
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except ObjectDeletedError:
logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
@@ -276,13 +279,14 @@ class IndexingRunner:
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -372,13 +376,14 @@ class IndexingRunner:
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -582,7 +587,6 @@ class IndexingRunner:

all_qa_documents.extend(format_documents)

def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
"""
@@ -734,6 +738,9 @@ class IndexingRunner:
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
document = DatasetDocument.query.filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedException()

update_params = {
DatasetDocument.indexing_status: after_indexing_status
@@ -781,3 +788,7 @@ class IndexingRunner:

class DocumentIsPausedException(Exception):
pass


class DocumentIsDeletedPausedException(Exception):
pass
@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):

chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))

if not chat_messages:
@@ -1,8 +1,7 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings


class XinferenceEmbedding(BaseEmbedding):
@@ -1,6 +1,6 @@
import enum

from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel


@@ -9,26 +9,31 @@ class LLMRunResult(BaseModel):
prompt_tokens: int
completion_tokens: int
source: list = None
function_call: dict = None


class MessageType(enum.Enum):
HUMAN = 'human'
USER = 'user'
ASSISTANT = 'assistant'
SYSTEM = 'system'


class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
type: MessageType = MessageType.USER
content: str = ''
function_call: dict = None


def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
additional_kwargs = {}
if message.function_call:
additional_kwargs['function_call'] = message.function_call
lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
elif message.type == MessageType.SYSTEM:
lc_messages.append(SystemMessage(content=message.content))

@@ -39,11 +44,21 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
message_kwargs = {
'content': message.content,
'type': MessageType.ASSISTANT
}

if 'function_call' in message.additional_kwargs:
message_kwargs['function_call'] = message.additional_kwargs['function_call']

prompt_messages.append(PromptMessage(**message_kwargs))
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
return prompt_messages
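With the HUMAN role renamed to USER and function_call carried on assistant messages, the two conversion helpers now round-trip tool calls between Dify's PromptMessage and LangChain messages. A short sketch of the assumed usage (the function name and arguments below are made up for illustration):

from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages

messages = [
    PromptMessage(type=MessageType.SYSTEM, content="You are a helpful assistant."),
    PromptMessage(type=MessageType.USER, content="What's the weather in Berlin?"),
    # assistant turn that requested a (hypothetical) tool call
    PromptMessage(type=MessageType.ASSISTANT, content="",
                  function_call={"name": "get_weather", "arguments": '{"city": "Berlin"}'}),
]
lc_messages = to_lc_messages(messages)  # the AIMessage keeps function_call in additional_kwargs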
@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}

if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]

if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']

return self._client.generate(**generate_kwargs)

@property
def base_model_name(self) -> str:
@@ -13,11 +13,12 @@ from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage,
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM
import logging

@@ -157,8 +158,11 @@ class BaseLLM(BaseProviderModel):
except Exception as ex:
raise self.handle_exceptions(ex)

function_call = None
if isinstance(result.generations[0][0], ChatGeneration):
completion_content = result.generations[0][0].message.content
if 'function_call' in result.generations[0][0].message.additional_kwargs:
function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
else:
completion_content = result.generations[0][0].text

@@ -191,7 +195,8 @@ class BaseLLM(BaseProviderModel):
return LLMRunResult(
content=completion_content,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
completion_tokens=completion_tokens,
function_call=function_call
)

@abstractmethod
@@ -227,7 +232,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return:
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -245,7 +250,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -260,7 +265,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.000001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
@@ -325,6 +330,85 @@ class BaseLLM(BaseProviderModel):
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops

def get_advanced_prompt(self, app_mode: str,
app_model_config: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:

model_mode = app_model_config.model_dict['mode']
conversation_histories_role = {}

raw_prompt_list = []
prompt_messages = []

if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
else:
raise Exception("app_mode or model_mode not support")

for prompt_item in raw_prompt_list:
prompt = prompt_item['text']

# set prompt template variables
prompt_template = PromptTemplateParser(template=prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

if '#context#' in prompt:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''

if '#query#' in prompt:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''

if '#histories#' in prompt:
if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, 2000)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''

prompt = prompt_template.format(
prompt_inputs
)

prompt = re.sub(r'<\|.*?\|>', '', prompt)

prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))

if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, 2000)
prompt_messages.extend(histories)

if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))

return prompt_messages

def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
@@ -337,17 +421,17 @@ class BaseLLM(BaseProviderModel):
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
context=context
{'context': context}
)

pre_prompt_content = ''
if pre_prompt:
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
**prompt_inputs
prompt_inputs
)

prompt = ''
@@ -380,10 +464,8 @@ class BaseLLM(BaseProviderModel):
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format(
histories=histories
)
prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format({'histories': histories})

prompt = ''
for order in prompt_rules['system_prompt_orders']:
@@ -394,10 +476,8 @@ class BaseLLM(BaseProviderModel):
elif order == 'histories_prompt':
prompt += histories_prompt_content

prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
query_prompt_content = prompt_template.format(
query=query
)
prompt_template = PromptTemplateParser(template=query_prompt)
query_prompt_content = prompt_template.format({'query': query})

prompt += query_prompt_content

@@ -428,6 +508,16 @@ class BaseLLM(BaseProviderModel):
external_context = memory.load_memory_variables({})
return external_context[memory_key]

def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])

def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
if not model_mode:
@@ -442,16 +532,7 @@ class BaseLLM(BaseProviderModel):
if len(messages) == 0:
return []

chat_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
chat_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
chat_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
chat_messages.append(SystemMessage(content=message.content))

return chat_messages
return to_lc_messages(messages)

def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
"""
@@ -1,26 +1,23 @@
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM


class MinimaxModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT

def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Minimax(
return MinimaxChatLLM(
model=self.name,
model_kwargs={
'stream': False
},
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)

def get_currency(self):
return 'RMB'
@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex

@property
def support_streaming(self):
return True
@@ -33,7 +33,7 @@ MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
@@ -106,7 +106,21 @@ class OpenAIModel(BaseLLM):
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")

prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)

generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}

if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]

if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']

return self._client.generate(**generate_kwargs)

def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
@@ -18,7 +18,6 @@ class TongyiModel(BaseLLM):

def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
del provider_model_kwargs['max_tokens']
return EnhanceTongyi(
model_name=self.name,
max_retries=1,
@@ -58,7 +57,6 @@ class TongyiModel(BaseLLM):

def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
del provider_model_kwargs['max_tokens']
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
},
{
'id': 'claude-2',
'name': 'claude-2',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
else:
return []

def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value

def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -12,7 +12,7 @@ from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
}

credentials = json.loads(provider_model.encrypted_config)

if provider_model.model_type == ModelType.TEXT_GENERATION.value:
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])

if credentials['base_model_name'] in [
'gpt-4',
'gpt-4-32k',
@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):

return model_list

def _get_text_generation_model_mode(self, model_name) -> str:
if model_name == 'text-davinci-003':
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value

def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]
@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
@@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider):
Returns the name of a provider.
"""
return 'baichuan'

def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value

def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
@@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider):
{
'id': 'baichuan2-53b',
'name': 'Baichuan2-53B',
'mode': ModelMode.CHAT.value,
}
]
else:
@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()

return [{
'id': provider_model.model_name,
'name': provider_model.model_name
} for provider_model in provider_models]
provider_model_list = []
for provider_model in provider_models:
provider_model_dict = {
'id': provider_model.model_name,
'name': provider_model.model_name
}

if model_type == ModelType.TEXT_GENERATION:
provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)

provider_model_list.append(provider_model_dict)

return provider_model_list

@abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
"""
raise NotImplementedError

@abstractmethod
def _get_text_generation_model_mode(self, model_name) -> str:
"""
get text generation model mode.

:param model_name:
:return:
"""
raise NotImplementedError

@abstractmethod
def get_model_class(self, model_type: ModelType) -> Type:
"""
@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []

def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value

def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -5,7 +5,7 @@ import requests
from huggingface_hub import HfApi

from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []

def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value

def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []

def _get_text_generation_model_mode(self, model_name) -> str:
credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
if credentials['completion_type'] == 'chat_completion':
return ModelMode.CHAT.value
else:
return ModelMode.COMPLETION.value

def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -2,14 +2,15 @@ import json
from json import JSONDecodeError
from typing import Type

from langchain.llms import Minimax
from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
from models.provider import ProviderType, ProviderQuotaType


@@ -28,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
{
'id': 'abab5.5-chat',
'name': 'abab5.5-chat',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'abab5-chat',
'name': 'abab5-chat',
'mode': ModelMode.COMPLETION.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -44,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
else:
return []

def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value

def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -98,14 +104,14 @@ class MinimaxProvider(BaseModelProvider):
'minimax_api_key': credentials['minimax_api_key'],
}

llm = Minimax(
llm = MinimaxChatLLM(
model='abab5.5-chat',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)

llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]

@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
else:
return []

def _get_text_generation_model_mode(self, model_name) -> str:
if model_name in COMPLETION_MODELS:
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value

def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -132,7 +144,7 @@ class OpenAIProvider(BaseModelProvider):
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
@@ -3,7 +3,7 @@ from typing import Type

from core.helper import encrypter
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

@@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []

def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value

def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -6,7 +6,8 @@ import replicate
from replicate.exceptions import ReplicateError

from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []

def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value

def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider):
{
'id': 'spark',
'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value,
}
]
else:
return []

def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value

def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@ -4,7 +4,7 @@ from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.tongyi_model import TongyiModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
|
||||
@ -24,17 +24,22 @@ class TongyiProvider(BaseModelProvider):
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
return [
|
||||
{
|
||||
'id': 'qwen-v1',
|
||||
'name': 'qwen-v1',
|
||||
'id': 'qwen-turbo',
|
||||
'name': 'qwen-turbo',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
},
|
||||
{
|
||||
'id': 'qwen-plus-v1',
|
||||
'name': 'qwen-plus-v1',
|
||||
'id': 'qwen-plus',
|
||||
'name': 'qwen-plus',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
@ -58,16 +63,16 @@ class TongyiProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
model_max_tokens = {
|
||||
'qwen-v1': 1500,
|
||||
'qwen-plus-v1': 6500
|
||||
'qwen-turbo': 6000,
|
||||
'qwen-plus': 6000
|
||||
}
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](enabled=False),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
|
||||
temperature=KwargRule[float](min=0.01, max=1, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0.01, max=0.99, default=0.5, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
|
||||
max_tokens=KwargRule[int](enabled=False, max=model_max_tokens.get(model_name)),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -84,7 +89,7 @@ class TongyiProvider(BaseModelProvider):
|
||||
}
|
||||
|
||||
llm = EnhanceTongyi(
|
||||
model_name='qwen-v1',
|
||||
model_name='qwen-turbo',
|
||||
max_retries=1,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.wenxin_model import WenxinModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.wenxin import Wenxin
|
||||
@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'ernie-bot',
|
||||
'name': 'ERNIE-Bot',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
},
|
||||
{
|
||||
'id': 'ernie-bot-turbo',
|
||||
'name': 'ERNIE-Bot-turbo',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
},
|
||||
{
|
||||
'id': 'bloomz-7b',
|
||||
'name': 'BLOOMZ-7B',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
@@ -2,15 +2,15 @@ import json
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.xinference_model import XinferenceModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
|
||||
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
|
||||
@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'chatglm_pro',
|
||||
'name': 'chatglm_pro',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
{
|
||||
'id': 'chatglm_std',
|
||||
'name': 'chatglm_std',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
{
|
||||
'id': 'chatglm_lite',
|
||||
'name': 'chatglm_lite',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
{
|
||||
'id': 'chatglm_lite_32k',
|
||||
'name': 'chatglm_lite_32k',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
}
|
||||
]
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider):
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
@@ -3,5 +3,19 @@
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "fixed"
    "model_flexibility": "fixed",
    "price_config": {
        "qwen-turbo": {
            "prompt": "0.012",
            "completion": "0.012",
            "unit": "0.001",
            "currency": "RMB"
        },
        "qwen-plus": {
            "prompt": "0.14",
            "completion": "0.14",
            "unit": "0.001",
            "currency": "RMB"
        }
    }
}
@@ -1,7 +1,5 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
from langchain import WikipediaAPIWrapper
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
@@ -27,7 +25,6 @@ from core.tool.web_reader_tool import WebReaderTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.model import AppModelConfig
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class OrchestratorRuleParser:
|
||||
@@ -39,18 +36,20 @@ class OrchestratorRuleParser:
|
||||
|
||||
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
|
||||
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
|
||||
return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
|
||||
retriever_from: str = 'dev') -> Optional[AgentExecutor]:
|
||||
if not self.app_model_config.agent_mode_dict:
|
||||
return None
|
||||
|
||||
agent_mode_config = self.app_model_config.agent_mode_dict
|
||||
model_dict = self.app_model_config.model_dict
|
||||
return_resource = self.app_model_config.retriever_resource_dict.get('enabled', False)
|
||||
|
||||
chain = None
|
||||
if agent_mode_config and agent_mode_config.get('enabled'):
|
||||
tool_configs = agent_mode_config.get('tools', [])
|
||||
agent_provider_name = model_dict.get('provider', 'openai')
|
||||
agent_model_name = model_dict.get('name', 'gpt-4')
|
||||
dataset_configs = self.app_model_config.dataset_configs_dict
|
||||
|
||||
agent_model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=self.tenant_id,
|
||||
@@ -77,7 +76,7 @@ class OrchestratorRuleParser:
|
||||
# only OpenAI chat model (include Azure) support function call, use ReACT instead
|
||||
if agent_model_instance.model_mode != ModelMode.CHAT \
|
||||
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
|
||||
if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
|
||||
if planning_strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
planning_strategy = PlanningStrategy.REACT
|
||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
@@ -97,13 +96,14 @@ class OrchestratorRuleParser:
|
||||
summary_model_instance = None
|
||||
|
||||
tools = self.to_tools(
|
||||
agent_model_instance=agent_model_instance,
|
||||
tool_configs=tool_configs,
|
||||
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
|
||||
agent_model_instance=agent_model_instance,
|
||||
conversation_message_task=conversation_message_task,
|
||||
rest_tokens=rest_tokens,
|
||||
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
|
||||
return_resource=return_resource,
|
||||
retriever_from=retriever_from
|
||||
retriever_from=retriever_from,
|
||||
dataset_configs=dataset_configs
|
||||
)
|
||||
|
||||
if len(tools) == 0:
|
||||
@@ -171,20 +171,12 @@ class OrchestratorRuleParser:
|
||||
|
||||
return None
|
||||
|
||||
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
|
||||
retriever_from: str = 'dev') -> list[BaseTool]:
|
||||
def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
|
||||
"""
|
||||
Convert app agent tool configs to tools
|
||||
|
||||
:param agent_model_instance:
|
||||
:param rest_tokens:
|
||||
:param tool_configs: app agent tool configs
|
||||
:param conversation_message_task:
|
||||
:param callbacks:
|
||||
:param return_resource:
|
||||
:param retriever_from:
|
||||
:return:
|
||||
"""
|
||||
tools = []
|
||||
@@ -196,29 +188,35 @@ class OrchestratorRuleParser:
|
||||
|
||||
tool = None
|
||||
if tool_type == "dataset":
|
||||
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
|
||||
tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
|
||||
elif tool_type == "web_reader":
|
||||
tool = self.to_web_reader_tool(agent_model_instance)
|
||||
tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
|
||||
elif tool_type == "google_search":
|
||||
tool = self.to_google_search_tool()
|
||||
tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
|
||||
elif tool_type == "wikipedia":
|
||||
tool = self.to_wikipedia_tool()
|
||||
tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
|
||||
elif tool_type == "current_datetime":
|
||||
tool = self.to_current_datetime_tool()
|
||||
tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)
|
||||
|
||||
if tool:
|
||||
tool.callbacks.extend(callbacks)
|
||||
if tool.callbacks is not None:
|
||||
tool.callbacks.extend(callbacks)
|
||||
else:
|
||||
tool.callbacks = callbacks
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
|
||||
rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
|
||||
dataset_configs: dict, rest_tokens: int,
|
||||
return_resource: bool = False, retriever_from: str = 'dev',
|
||||
**kwargs) \
|
||||
-> Optional[BaseTool]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param rest_tokens:
|
||||
:param tool_config:
|
||||
:param dataset_configs:
|
||||
:param conversation_message_task:
|
||||
:param return_resource:
|
||||
:param retriever_from:
|
||||
@@ -236,10 +234,20 @@ class OrchestratorRuleParser:
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||
return None
|
||||
|
||||
k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
|
||||
top_k = dataset_configs.get("top_k", 2)
|
||||
|
||||
# dynamically adjust top_k when the remaining token number is not enough to support top_k
|
||||
top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
|
||||
|
||||
score_threshold = None
|
||||
score_threshold_config = dataset_configs.get("score_threshold")
|
||||
if score_threshold_config and score_threshold_config.get("enable"):
|
||||
score_threshold = score_threshold_config.get("value")
|
||||
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
k=k,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
|
||||
conversation_message_task=conversation_message_task,
|
||||
return_resource=return_resource,
|
||||
@@ -248,7 +256,7 @@ class OrchestratorRuleParser:
|
||||
|
||||
return tool
|
||||
|
||||
def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
|
||||
def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
|
||||
"""
|
||||
A tool for reading web pages
|
||||
|
||||
@@ -269,15 +277,14 @@ class OrchestratorRuleParser:
|
||||
summary_model_instance = None
|
||||
|
||||
tool = WebReaderTool(
|
||||
llm=summary_model_instance.client if summary_model_instance else None,
|
||||
model_instance=summary_model_instance if summary_model_instance else None,
|
||||
max_chunk_length=4000,
|
||||
continue_reading=True,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
continue_reading=True
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def to_google_search_tool(self) -> Optional[BaseTool]:
|
||||
def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
|
||||
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
|
||||
func_kwargs = tool_provider.credentials_to_func_kwargs()
|
||||
if not func_kwargs:
|
||||
@@ -290,47 +297,39 @@ class OrchestratorRuleParser:
|
||||
"is not up to date. "
|
||||
"Input should be a search query.",
|
||||
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
|
||||
args_schema=OptimizedSerpAPIInput,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
args_schema=OptimizedSerpAPIInput
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def to_current_datetime_tool(self) -> Optional[BaseTool]:
|
||||
tool = DatetimeTool(
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
)
|
||||
def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
|
||||
tool = DatetimeTool()
|
||||
|
||||
return tool
|
||||
|
||||
def to_wikipedia_tool(self) -> Optional[BaseTool]:
|
||||
def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
|
||||
class WikipediaInput(BaseModel):
|
||||
query: str = Field(..., description="search query.")
|
||||
|
||||
return WikipediaQueryRun(
|
||||
name="wikipedia",
|
||||
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
|
||||
args_schema=WikipediaInput,
|
||||
callbacks=[DifyStdOutCallbackHandler()]
|
||||
args_schema=WikipediaInput
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
|
||||
DEFAULT_K = 2
|
||||
CONTEXT_TOKENS_PERCENT = 0.3
|
||||
MAX_K = 10
|
||||
|
||||
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
|
||||
if rest_tokens == -1:
|
||||
return DEFAULT_K
|
||||
return top_k
|
||||
|
||||
processing_rule = dataset.latest_process_rule
|
||||
if not processing_rule:
|
||||
return DEFAULT_K
|
||||
return top_k
|
||||
|
||||
if processing_rule.mode == "custom":
|
||||
rules = processing_rule.rules_dict
|
||||
if not rules:
|
||||
return DEFAULT_K
|
||||
return top_k
|
||||
|
||||
segmentation = rules["segmentation"]
|
||||
segment_max_tokens = segmentation["max_tokens"]
|
||||
@@ -338,14 +337,7 @@ class OrchestratorRuleParser:
|
||||
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
|
||||
|
||||
# when rest_tokens is less than default context tokens
|
||||
if rest_tokens < segment_max_tokens * DEFAULT_K:
|
||||
if rest_tokens < segment_max_tokens * top_k:
|
||||
return rest_tokens // segment_max_tokens
|
||||
|
||||
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
|
||||
|
||||
# when context_limit_tokens is less than default context tokens, use default_k
|
||||
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
|
||||
return DEFAULT_K
|
||||
|
||||
# Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
|
||||
return min(context_limit_tokens // segment_max_tokens, MAX_K)
|
||||
return min(top_k, 10)
|
||||
|
||||
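The retrieval sizing change above can be summarised in a small standalone sketch (illustrative only; the helper name is assumed, while the -1 sentinel and the cap of 10 are taken from the hunk):

    # Sketch of the new top_k adjustment; not part of the patch itself.
    def adjust_top_k(top_k: int, rest_tokens: int, segment_max_tokens: int) -> int:
        if rest_tokens == -1:
            # unknown/unlimited budget: keep the configured top_k
            return top_k
        if rest_tokens < segment_max_tokens * top_k:
            # shrink top_k so the retrieved segments still fit the remaining prompt budget
            return rest_tokens // segment_max_tokens
        # otherwise keep the configured value, hard-capped at 10 as in the patch
        return min(top_k, 10)

    # Example: adjust_top_k(top_k=4, rest_tokens=1500, segment_max_tokens=500) returns 3.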
83  api/core/prompt/advanced_prompt_templates.py  Normal file
@@ -0,0 +1,83 @@
|
||||
CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n"
|
||||
|
||||
BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"
|
||||
|
||||
CHAT_APP_COMPLETION_PROMPT_CONFIG = {
|
||||
"completion_prompt_config": {
|
||||
"prompt": {
|
||||
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
|
||||
},
|
||||
"conversation_histories_role": {
|
||||
"user_prefix": "Human",
|
||||
"assistant_prefix": "Assistant"
|
||||
}
|
||||
},
|
||||
"stop": ["Human:"]
|
||||
}
|
||||
|
||||
CHAT_APP_CHAT_PROMPT_CONFIG = {
|
||||
"chat_prompt_config": {
|
||||
"prompt": [{
|
||||
"role": "system",
|
||||
"text": "{{#pre_prompt#}}"
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
COMPLETION_APP_CHAT_PROMPT_CONFIG = {
|
||||
"chat_prompt_config": {
|
||||
"prompt": [{
|
||||
"role": "user",
|
||||
"text": "{{#pre_prompt#}}"
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
|
||||
"completion_prompt_config": {
|
||||
"prompt": {
|
||||
"text": "{{#pre_prompt#}}"
|
||||
}
|
||||
},
|
||||
"stop": ["Human:"]
|
||||
}
|
||||
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
|
||||
"completion_prompt_config": {
|
||||
"prompt": {
|
||||
"text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
|
||||
},
|
||||
"conversation_histories_role": {
|
||||
"user_prefix": "用户",
|
||||
"assistant_prefix": "助手"
|
||||
}
|
||||
},
|
||||
"stop": ["用户:"]
|
||||
}
|
||||
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
|
||||
"chat_prompt_config": {
|
||||
"prompt": [{
|
||||
"role": "system",
|
||||
"text": "{{#pre_prompt#}}"
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
|
||||
"chat_prompt_config": {
|
||||
"prompt": [{
|
||||
"role": "user",
|
||||
"text": "{{#pre_prompt#}}"
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
|
||||
"completion_prompt_config": {
|
||||
"prompt": {
|
||||
"text": "{{#pre_prompt#}}"
|
||||
}
|
||||
},
|
||||
"stop": ["用户:"]
|
||||
}
|
||||
@@ -1,38 +1,24 @@
|
||||
import re
|
||||
from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
|
||||
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
@classmethod
|
||||
def parse_prompt(cls, prompt: str, inputs: dict) -> str:
|
||||
prompt_template = PromptTemplateParser(prompt)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
prompt = prompt_template.format(prompt_inputs)
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
|
||||
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
|
||||
system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
|
||||
prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
|
||||
system_message = system_prompt_template.format(**prompt_inputs)
|
||||
return system_message
|
||||
return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
|
||||
|
||||
@classmethod
|
||||
def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
|
||||
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
|
||||
ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
|
||||
prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
|
||||
ai_message = ai_prompt_template.format(**prompt_inputs)
|
||||
return ai_message
|
||||
return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
|
||||
|
||||
@classmethod
|
||||
def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
|
||||
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
|
||||
human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
|
||||
human_message = human_prompt_template.format(**inputs)
|
||||
return human_message
|
||||
|
||||
@classmethod
|
||||
def process_template(cls, template: str):
|
||||
processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
|
||||
# processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
|
||||
# processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
|
||||
return processed_template
|
||||
return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))
|
||||
|
||||
@@ -1,79 +1,39 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import Environment, meta
|
||||
from langchain import PromptTemplate
|
||||
from langchain.formatting import StrictFormatter
|
||||
REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}")
|
||||
|
||||
|
||||
class JinjaPromptTemplate(PromptTemplate):
|
||||
template_format: str = "jinja2"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
class PromptTemplateParser:
|
||||
"""
|
||||
Rules:
|
||||
|
||||
1. Template variables must be enclosed in `{{}}`.
|
||||
2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters,
|
||||
and can only start with letters and underscores.
|
||||
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
|
||||
4. In addition to the above, 3 types of special template variable Keys are accepted:
|
||||
`{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.template = template
|
||||
self.variable_keys = self.extract()
|
||||
|
||||
def extract(self) -> list:
|
||||
# Regular expression to match the template rules
|
||||
return re.findall(REGEX, self.template)
|
||||
|
||||
def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
|
||||
def replacer(match):
|
||||
key = match.group(1)
|
||||
value = inputs.get(key, match.group(0)) # return original matched string if key not found
|
||||
|
||||
if remove_template_variables:
|
||||
return PromptTemplateParser.remove_template_variables(value)
|
||||
return value
|
||||
|
||||
return re.sub(REGEX, replacer, self.template)
|
||||
|
||||
@classmethod
|
||||
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
|
||||
"""Load a prompt template from a template."""
|
||||
env = Environment()
|
||||
template = template.replace("{{}}", "{}")
|
||||
ast = env.parse(template)
|
||||
input_variables = meta.find_undeclared_variables(ast)
|
||||
|
||||
if "partial_variables" in kwargs:
|
||||
partial_variables = kwargs["partial_variables"]
|
||||
input_variables = {
|
||||
var for var in input_variables if var not in partial_variables
|
||||
}
|
||||
|
||||
return cls(
|
||||
input_variables=list(sorted(input_variables)), template=template, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class OutLinePromptTemplate(PromptTemplate):
|
||||
@classmethod
|
||||
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
|
||||
"""Load a prompt template from a template."""
|
||||
input_variables = {
|
||||
v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
|
||||
}
|
||||
return cls(
|
||||
input_variables=list(sorted(input_variables)), template=template, **kwargs
|
||||
)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt.format(variable1="foo")
|
||||
"""
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
return OneLineFormatter().format(self.template, **kwargs)
|
||||
|
||||
|
||||
class OneLineFormatter(StrictFormatter):
|
||||
def parse(self, format_string):
|
||||
last_end = 0
|
||||
results = []
|
||||
for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
|
||||
field_name = match.group(1)
|
||||
start, end = match.span()
|
||||
|
||||
literal_text = format_string[last_end:start]
|
||||
last_end = end
|
||||
|
||||
results.append((literal_text, field_name, '', None))
|
||||
|
||||
remaining_literal_text = format_string[last_end:]
|
||||
if remaining_literal_text:
|
||||
results.append((remaining_literal_text, None, None, None))
|
||||
|
||||
return results
|
||||
def remove_template_variables(cls, text: str):
|
||||
return re.sub(REGEX, r'{\1}', text)
|
||||
|
||||
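A short usage sketch of the PromptTemplateParser defined above (the template and input values are made up for illustration; the API is the one shown in the hunk):

    parser = PromptTemplateParser("Hello {{name}}, context: {{#context#}}")
    parser.variable_keys
    # -> ['name', '#context#']
    parser.format({'name': 'Dify', '#context#': 'retrieved text'})
    # -> 'Hello Dify, context: retrieved text'
    # Keys missing from the inputs are left as the original '{{key}}' text.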
@@ -61,36 +61,6 @@ User Input: yo, 你今天咋样?
|
||||
User Input:
|
||||
"""
|
||||
|
||||
CONVERSATION_SUMMARY_PROMPT = (
|
||||
"Please generate a short summary of the following conversation.\n"
|
||||
"If the following conversation communicating in English, you should only return an English summary.\n"
|
||||
"If the following conversation communicating in Chinese, you should only return a Chinese summary.\n"
|
||||
"[Conversation Start]\n"
|
||||
"{context}\n"
|
||||
"[Conversation End]\n\n"
|
||||
"summary:"
|
||||
)
|
||||
|
||||
INTRODUCTION_GENERATE_PROMPT = (
|
||||
"I am designing a product for users to interact with an AI through dialogue. "
|
||||
"The Prompt given to the AI before the conversation is:\n\n"
|
||||
"```\n{prompt}\n```\n\n"
|
||||
"Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
|
||||
"Do not reveal the developer's motivation or deep logic behind the Prompt, "
|
||||
"but focus on building a relationship with the user:\n"
|
||||
)
|
||||
|
||||
MORE_LIKE_THIS_GENERATE_PROMPT = (
|
||||
"-----\n"
|
||||
"{original_completion}\n"
|
||||
"-----\n\n"
|
||||
"Please use the above content as a sample for generating the result, "
|
||||
"and include key information points related to the original sample in the result. "
|
||||
"Try to rephrase this information in different ways and predict according to the rules below.\n\n"
|
||||
"-----\n"
|
||||
"{prompt}\n"
|
||||
)
|
||||
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
"Please help me predict the three most likely questions that human would ask, "
|
||||
"and keeping each question under 20 characters.\n"
|
||||
@@ -157,10 +127,10 @@ and fill in variables, with a welcome sentence, and keep TLDR.
|
||||
```
|
||||
|
||||
<< MY INTENDED AUDIENCES >>
|
||||
{audiences}
|
||||
{{audiences}}
|
||||
|
||||
<< HOPING TO SOLVE >>
|
||||
{hoping_to_solve}
|
||||
{{hoping_to_solve}}
|
||||
|
||||
<< OUTPUT >>
|
||||
"""
|
||||
@@ -1,21 +1,54 @@
|
||||
from typing import List
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from xinference_client.client.restful.restful_client import Client
|
||||
|
||||
|
||||
class XinferenceEmbedding(XinferenceEmbeddings):
|
||||
class XinferenceEmbeddings(Embeddings):
|
||||
client: Any
|
||||
server_url: Optional[str]
|
||||
"""URL of the xinference server"""
|
||||
model_uid: Optional[str]
|
||||
"""UID of the launched model"""
|
||||
|
||||
def __init__(
|
||||
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if server_url is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if model_uid is None:
|
||||
raise ValueError("Please provide the model UID")
|
||||
|
||||
self.server_url = server_url
|
||||
|
||||
self.model_uid = model_uid
|
||||
|
||||
self.client = Client(server_url)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
vectors = super().embed_documents(texts)
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
embeddings = [
|
||||
model.create_embedding(text)["data"][0]["embedding"] for text in texts
|
||||
]
|
||||
vectors = [list(map(float, e)) for e in embeddings]
|
||||
normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
|
||||
|
||||
return normalized_vectors
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
vector = super().embed_query(text)
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
embedding_res = model.create_embedding(text)
|
||||
|
||||
embedding = embedding_res["data"][0]["embedding"]
|
||||
|
||||
vector = list(map(float, embedding))
|
||||
normalized_vector = (vector / np.linalg.norm(vector)).tolist()
|
||||
|
||||
return normalized_vector
|
||||
|
||||
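A minimal usage sketch of the rewritten XinferenceEmbeddings wrapper (the server URL and model UID below are placeholders, not values from the patch):

    # Placeholders only; point these at a running Xinference server.
    embedder = XinferenceEmbeddings(server_url="http://127.0.0.1:9997", model_uid="my-embedding-model")
    doc_vectors = embedder.embed_documents(["hello", "world"])  # L2-normalised vectors, one per text
    query_vector = embedder.embed_query("hello")                # single L2-normalised vector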
283  api/core/third_party/langchain/llms/minimax_llm.py  vendored  Normal file
@@ -0,0 +1,283 @@
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List, Tuple, Iterator
|
||||
|
||||
import requests
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema import BaseMessage, ChatResult, HumanMessage, AIMessage, SystemMessage
|
||||
from langchain.schema.messages import AIMessageChunk
|
||||
from langchain.schema.output import ChatGenerationChunk, ChatGeneration
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from pydantic import root_validator, Field, BaseModel
|
||||
|
||||
|
||||
class _MinimaxEndpointClient(BaseModel):
|
||||
"""An API client that talks to a Minimax llm endpoint."""
|
||||
|
||||
host: str
|
||||
group_id: str
|
||||
api_key: str
|
||||
api_url: str
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "api_url" not in values:
|
||||
host = values["host"]
|
||||
group_id = values["group_id"]
|
||||
api_url = f"{host}/v1/text/chatcompletion?GroupId={group_id}"
|
||||
values["api_url"] = api_url
|
||||
return values
|
||||
|
||||
def post(self, **request: Any) -> Any:
|
||||
stream = 'stream' in request and request['stream']
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
response = requests.post(self.api_url, headers=headers, json=request, stream=stream, timeout=(5, 60))
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
if not stream:
|
||||
if response.json()["base_resp"]["status_code"] > 0:
|
||||
raise ValueError(
|
||||
f"API {response.json()['base_resp']['status_code']}"
|
||||
f" error: {response.json()['base_resp']['status_msg']}"
|
||||
)
|
||||
return response.json()
|
||||
else:
|
||||
return response
|
||||
|
||||
|
||||
class MinimaxChatLLM(BaseChatModel):
|
||||
|
||||
_client: _MinimaxEndpointClient
|
||||
model: str = "abab5.5-chat"
|
||||
"""Model name to use."""
|
||||
max_tokens: int = 256
|
||||
"""Denotes the number of tokens to predict per generation."""
|
||||
temperature: float = 0.7
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
top_p: float = 0.95
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the response or return it all at once."""
|
||||
minimax_api_host: Optional[str] = None
|
||||
minimax_group_id: Optional[str] = None
|
||||
minimax_api_key: Optional[str] = None
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"minimax_api_key": "MINIMAX_API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["minimax_api_key"] = get_from_dict_or_env(
|
||||
values, "minimax_api_key", "MINIMAX_API_KEY"
|
||||
)
|
||||
values["minimax_group_id"] = get_from_dict_or_env(
|
||||
values, "minimax_group_id", "MINIMAX_GROUP_ID"
|
||||
)
|
||||
# Get custom api url from environment.
|
||||
values["minimax_api_host"] = get_from_dict_or_env(
|
||||
values,
|
||||
"minimax_api_host",
|
||||
"MINIMAX_API_HOST",
|
||||
default="https://api.minimax.chat",
|
||||
)
|
||||
values["_client"] = _MinimaxEndpointClient(
|
||||
host=values["minimax_api_host"],
|
||||
api_key=values["minimax_api_key"],
|
||||
group_id=values["minimax_group_id"],
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"tokens_to_generate": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"role_meta": {"user_name": "我", "bot_name": "专家"},
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model": self.model}, **self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "minimax"
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
||||
if isinstance(message, HumanMessage):
|
||||
message_dict = {"sender_type": "USER", "text": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"sender_type": "BOT", "text": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
def _create_messages_and_prompt(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> Tuple[List[Dict[str, Any]], str]:
|
||||
prompt = ""
|
||||
dict_messages = []
|
||||
for m in messages:
|
||||
if isinstance(m, SystemMessage):
|
||||
if prompt:
|
||||
prompt += "\n"
|
||||
prompt += f"{m.content}"
|
||||
continue
|
||||
|
||||
message = self._convert_message_to_dict(m)
|
||||
dict_messages.append(message)
|
||||
|
||||
prompt = prompt if prompt else ' '
|
||||
|
||||
return dict_messages, prompt
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
llm_output: Optional[Dict] = None
|
||||
for chunk in self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
|
||||
if chunk.generation_info is not None \
|
||||
and 'token_usage' in chunk.generation_info:
|
||||
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
|
||||
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation], llm_output=llm_output)
|
||||
else:
|
||||
message_dicts, prompt = self._create_messages_and_prompt(messages)
|
||||
params = self._default_params
|
||||
params["messages"] = message_dicts
|
||||
params["prompt"] = prompt
|
||||
params.update(kwargs)
|
||||
response = self._client.post(**params)
|
||||
return self._create_chat_result(response, stop)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, prompt = self._create_messages_and_prompt(messages)
|
||||
params = self._default_params
|
||||
params["messages"] = message_dicts
|
||||
params["prompt"] = prompt
|
||||
params["stream"] = True
|
||||
params.update(kwargs)
|
||||
|
||||
for token in self._client.post(**params).iter_lines():
|
||||
if token:
|
||||
token = token.decode("utf-8")
|
||||
|
||||
if not token.startswith("data:"):
|
||||
data = json.loads(token)
|
||||
if "base_resp" in data and data["base_resp"]["status_code"] > 0:
|
||||
raise ValueError(
|
||||
f"API {data['base_resp']['status_code']}"
|
||||
f" error: {data['base_resp']['status_msg']}"
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
token = token.lstrip("data:").strip()
|
||||
data = json.loads(token)
|
||||
|
||||
if "base_resp" in data and data["base_resp"]["status_code"] > 0:
|
||||
raise ValueError(
|
||||
f"API {data['base_resp']['status_code']}"
|
||||
f" error: {data['base_resp']['status_msg']}"
|
||||
)
|
||||
|
||||
if not data['choices']:
|
||||
continue
|
||||
|
||||
content = data['choices'][0]['delta']
|
||||
|
||||
chunk_kwargs = {
|
||||
'message': AIMessageChunk(content=content),
|
||||
}
|
||||
|
||||
if 'usage' in data:
|
||||
token_usage = data['usage']
|
||||
overall_token_usage = {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': token_usage.get('total_tokens', 0),
|
||||
'total_tokens': token_usage.get('total_tokens', 0)
|
||||
}
|
||||
chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}
|
||||
|
||||
yield ChatGenerationChunk(**chunk_kwargs)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(content)
|
||||
|
||||
def _create_chat_result(self, response: Dict[str, Any], stop: Optional[List[str]] = None) -> ChatResult:
|
||||
text = response['reply']
|
||||
if stop is not None:
|
||||
# This is required since the stop tokens
|
||||
# are not enforced by the model parameters
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
generations = [ChatGeneration(message=AIMessage(content=text))]
|
||||
usage = response.get("usage")
|
||||
|
||||
# only return total_tokens in minimax response
|
||||
token_usage = {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': usage.get('total_tokens', 0),
|
||||
'total_tokens': usage.get('total_tokens', 0)
|
||||
}
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
return sum([self.get_num_tokens(m.content) for m in messages])
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
|
||||
return {"token_usage": token_usage, "model_name": self.model}
|
||||
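A minimal instantiation sketch for the new MinimaxChatLLM (the credentials are placeholders, and the call style assumes the LangChain BaseChatModel interface of this version):

    from langchain.schema import HumanMessage, SystemMessage

    llm = MinimaxChatLLM(
        minimax_api_key="your-api-key",        # placeholder
        minimax_group_id="your-group-id",      # placeholder
        model="abab5.5-chat",
        streaming=False,
    )
    # Chat-model call interface: pass a list of messages, receive an AIMessage back.
    reply = llm([SystemMessage(content="You are a helpful assistant."), HumanMessage(content="Hi")])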
@@ -1,16 +1,53 @@
|
||||
from typing import Optional, List, Any, Union, Generator
|
||||
from typing import Optional, List, Any, Union, Generator, Mapping
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Xinference
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from xinference.client import (
|
||||
from xinference_client.client.restful.restful_client import (
|
||||
RESTfulChatglmCppChatModelHandle,
|
||||
RESTfulChatModelHandle,
|
||||
RESTfulGenerateModelHandle,
|
||||
RESTfulGenerateModelHandle, Client,
|
||||
)
|
||||
|
||||
|
||||
class XinferenceLLM(Xinference):
|
||||
class XinferenceLLM(LLM):
|
||||
client: Any
|
||||
server_url: Optional[str]
|
||||
"""URL of the xinference server"""
|
||||
model_uid: Optional[str]
|
||||
"""UID of the launched model"""
|
||||
|
||||
def __init__(
|
||||
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
|
||||
):
|
||||
super().__init__(
|
||||
**{
|
||||
"server_url": server_url,
|
||||
"model_uid": model_uid,
|
||||
}
|
||||
)
|
||||
|
||||
if self.server_url is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if self.model_uid is None:
|
||||
raise ValueError("Please provide the model UID")
|
||||
|
||||
self.client = Client(server_url)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "xinference"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
**{"server_url": self.server_url},
|
||||
**{"model_uid": self.model_uid},
|
||||
}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Type
|
||||
from typing import Type, Optional
|
||||
|
||||
from flask import current_app
|
||||
from langchain.tools import BaseTool
|
||||
@@ -28,9 +28,10 @@ class DatasetRetrieverTool(BaseTool):
|
||||
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
k: int = 3
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
conversation_message_task: ConversationMessageTask
|
||||
return_resource: str
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
|
||||
@classmethod
|
||||
@@ -66,7 +67,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||
)
|
||||
)
|
||||
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
|
||||
@@ -80,20 +81,21 @@ class DatasetRetrieverTool(BaseTool):
|
||||
return ''
|
||||
except ProviderTokenNotInitError:
|
||||
return ''
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
if self.k > 0:
|
||||
if self.top_k > 0:
|
||||
documents = vector_index.search(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': self.k,
|
||||
'k': self.top_k,
|
||||
'score_threshold': self.score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ from typing import Type
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup, NavigableString, Comment, CData
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.summarize import load_summarize_chain
|
||||
from langchain.chains import RefineDocumentsChain
|
||||
from langchain.chains.summarize import refine_prompts
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.tools.base import BaseTool
|
||||
@@ -20,8 +20,10 @@ from newspaper import Article
|
||||
from pydantic import BaseModel, Field
|
||||
from regex import regex
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.data_loader import file_extractor
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
FULL_TEMPLATE = """
|
||||
TITLE: {title}
|
||||
@@ -65,7 +67,7 @@ class WebReaderTool(BaseTool):
|
||||
summary_chunk_overlap: int = 0
|
||||
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
|
||||
continue_reading: bool = True
|
||||
llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM = None
|
||||
|
||||
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
|
||||
try:
|
||||
@@ -78,7 +80,7 @@ class WebReaderTool(BaseTool):
|
||||
except Exception as e:
|
||||
return f'Read this website failed, caused by: {str(e)}.'
|
||||
|
||||
if summary and self.llm:
|
||||
if summary and self.model_instance:
|
||||
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
||||
chunk_size=self.summary_chunk_tokens,
|
||||
chunk_overlap=self.summary_chunk_overlap,
|
||||
@@ -95,10 +97,9 @@ class WebReaderTool(BaseTool):
|
||||
if len(docs) > 5:
|
||||
docs = docs[:5]
|
||||
|
||||
chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
|
||||
chain = self.get_summary_chain()
|
||||
try:
|
||||
page_contents = chain.run(docs)
|
||||
# todo use cache
|
||||
except Exception as e:
|
||||
return f'Read this website failed, caused by: {str(e)}.'
|
||||
else:
|
||||
@@ -114,6 +115,23 @@ class WebReaderTool(BaseTool):
|
||||
async def _arun(self, url: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_summary_chain(self) -> RefineDocumentsChain:
|
||||
initial_chain = LLMChain(
|
||||
model_instance=self.model_instance,
|
||||
prompt=refine_prompts.PROMPT
|
||||
)
|
||||
refine_chain = LLMChain(
|
||||
model_instance=self.model_instance,
|
||||
prompt=refine_prompts.REFINE_PROMPT
|
||||
)
|
||||
return RefineDocumentsChain(
|
||||
initial_llm_chain=initial_chain,
|
||||
refine_llm_chain=refine_chain,
|
||||
document_variable_name="text",
|
||||
initial_response_name="existing_answer",
|
||||
callbacks=self.callbacks
|
||||
)
|
||||
|
||||
|
||||
def page_result(text: str, cursor: int, max_length: int) -> str:
|
||||
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
|
||||
|
||||
@@ -4,5 +4,4 @@ from .clean_when_document_deleted import handle
from .clean_when_dataset_deleted import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .generate_conversation_name_when_first_message_created import handle
from .generate_conversation_summary_when_few_message_created import handle
from .create_document_index import handle
@@ -1,14 +0,0 @@
from events.message_event import message_was_created
from tasks.generate_conversation_summary_task import generate_conversation_summary_task


@message_was_created.connect
def handle(sender, **kwargs):
    message = sender
    conversation = kwargs.get('conversation')
    is_first_message = kwargs.get('is_first_message')

    if not is_first_message and conversation.mode == 'chat' and not conversation.summary:
        history_message_count = conversation.message_count
        if history_message_count >= 5:
            generate_conversation_summary_task.delay(conversation.id)
@@ -28,6 +28,10 @@ model_config_fields = {
    'dataset_query_variable': fields.String,
    'pre_prompt': fields.String,
    'agent_mode': fields.Raw(attribute='agent_mode_dict'),
    'prompt_type': fields.String,
    'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
    'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
    'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
}

app_detail_fields = {
@@ -123,6 +123,7 @@ conversation_with_summary_fields = {
    'from_end_user_id': fields.String,
    'from_end_user_session_id': fields.String,
    'from_account_id': fields.String,
    'name': fields.String,
    'summary': fields.String(attribute='summary_or_query'),
    'read_at': TimestampField,
    'created_at': TimestampField,
@@ -0,0 +1,37 @@
|
||||
"""add advanced prompt templates
|
||||
|
||||
Revision ID: b3a09c049e8e
|
||||
Revises: 2e9819ca5b28
|
||||
Create Date: 2023-10-10 15:23:23.395420
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b3a09c049e8e'
|
||||
down_revision = '2e9819ca5b28'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
|
||||
batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
|
||||
batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
|
||||
batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.drop_column('dataset_configs')
|
||||
batch_op.drop_column('completion_prompt_config')
|
||||
batch_op.drop_column('chat_prompt_config')
|
||||
batch_op.drop_column('prompt_type')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -93,6 +93,10 @@ class AppModelConfig(db.Model):
|
||||
agent_mode = db.Column(db.Text)
|
||||
sensitive_word_avoidance = db.Column(db.Text)
|
||||
retriever_resource = db.Column(db.Text)
|
||||
prompt_type = db.Column(db.String(255), nullable=False, default='simple')
|
||||
chat_prompt_config = db.Column(db.Text)
|
||||
completion_prompt_config = db.Column(db.Text)
|
||||
dataset_configs = db.Column(db.Text)
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
@@ -139,6 +143,18 @@ class AppModelConfig(db.Model):
|
||||
def agent_mode_dict(self) -> dict:
|
||||
return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []}
|
||||
|
||||
@property
|
||||
def chat_prompt_config_dict(self) -> dict:
|
||||
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
|
||||
|
||||
@property
|
||||
def completion_prompt_config_dict(self) -> dict:
|
||||
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
|
||||
|
||||
@property
|
||||
def dataset_configs_dict(self) -> dict:
|
||||
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"provider": "",
|
||||
@@ -155,7 +171,11 @@ class AppModelConfig(db.Model):
|
||||
"user_input_form": self.user_input_form_list,
|
||||
"dataset_query_variable": self.dataset_query_variable,
|
||||
"pre_prompt": self.pre_prompt,
|
||||
"agent_mode": self.agent_mode_dict
|
||||
"agent_mode": self.agent_mode_dict,
|
||||
"prompt_type": self.prompt_type,
|
||||
"chat_prompt_config": self.chat_prompt_config_dict,
|
||||
"completion_prompt_config": self.completion_prompt_config_dict,
|
||||
"dataset_configs": self.dataset_configs_dict
|
||||
}
|
||||
|
||||
def from_model_config_dict(self, model_config: dict):
|
||||
@@ -177,6 +197,13 @@ class AppModelConfig(db.Model):
|
||||
self.agent_mode = json.dumps(model_config['agent_mode'])
|
||||
self.retriever_resource = json.dumps(model_config['retriever_resource']) \
|
||||
if model_config.get('retriever_resource') else None
|
||||
self.prompt_type = model_config.get('prompt_type', 'simple')
|
||||
self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \
|
||||
if model_config.get('chat_prompt_config') else None
|
||||
self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \
|
||||
if model_config.get('completion_prompt_config') else None
|
||||
self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
|
||||
if model_config.get('dataset_configs') else None
|
||||
return self
|
||||
|
||||
def copy(self):
|
||||
@@ -196,7 +223,12 @@ class AppModelConfig(db.Model):
|
||||
user_input_form=self.user_input_form,
|
||||
dataset_query_variable=self.dataset_query_variable,
|
||||
pre_prompt=self.pre_prompt,
|
||||
agent_mode=self.agent_mode
|
||||
agent_mode=self.agent_mode,
|
||||
retriever_resource=self.retriever_resource,
|
||||
prompt_type=self.prompt_type,
|
||||
chat_prompt_config=self.chat_prompt_config,
|
||||
completion_prompt_config=self.completion_prompt_config,
|
||||
dataset_configs=self.dataset_configs
|
||||
)
|
||||
|
||||
return new_app_model_config
|
||||
|
||||
@@ -44,12 +44,12 @@ readabilipy==0.2.0
google-search-results==2.4.2
replicate~=0.9.0
websocket-client~=1.6.1
dashscope~=1.5.0
dashscope~=1.11.0
huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.5.2
xinference-client~=0.1.2
safetensors==0.3.2
zhipuai==1.0.7
werkzeug==2.3.7
63  api/services/advanced_prompt_template_service.py  Normal file
@@ -0,0 +1,63 @@
|
||||
|
||||
import copy
|
||||
|
||||
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
|
||||
|
||||
class AdvancedPromptTemplateService:
|
||||
|
||||
@classmethod
|
||||
def get_prompt(cls, args: dict) -> dict:
|
||||
app_mode = args['app_mode']
|
||||
model_mode = args['model_mode']
|
||||
model_name = args['model_name']
|
||||
has_context = args['has_context']
|
||||
|
||||
if 'baichuan' in model_name:
|
||||
return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
|
||||
else:
|
||||
return cls.get_common_prompt(app_mode, model_mode, has_context)
|
||||
|
||||
@classmethod
|
||||
def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
|
||||
context_prompt = copy.deepcopy(CONTEXT)
|
||||
|
||||
if app_mode == 'chat':
|
||||
if model_mode == 'completion':
|
||||
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif model_mode == 'chat':
|
||||
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif app_mode == 'completion':
|
||||
if model_mode == 'completion':
|
||||
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif model_mode == 'chat':
|
||||
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
|
||||
@classmethod
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
if has_context == 'true':
|
||||
prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
|
||||
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
if has_context == 'true':
|
||||
prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
|
||||
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
|
||||
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
|
||||
|
||||
if app_mode == 'chat':
|
||||
if model_mode == 'completion':
|
||||
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
elif model_mode == 'chat':
|
||||
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
elif app_mode == 'completion':
|
||||
if model_mode == 'completion':
|
||||
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
elif model_mode == 'chat':
|
||||
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
@ -3,7 +3,7 @@ import uuid
|
||||
|
||||
from core.agent.agent_executor import PlanningStrategy
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelMode
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
@ -34,40 +34,31 @@ class AppModelConfigService:
|
||||
# max_tokens
|
||||
if 'max_tokens' not in cp:
|
||||
cp["max_tokens"] = 512
|
||||
#
|
||||
# if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
|
||||
# llm_constant.max_context_token_length[model_name]:
|
||||
# raise ValueError(
|
||||
# "max_tokens must be an integer greater than 0 "
|
||||
# "and not exceeding the maximum value of the corresponding model")
|
||||
#
|
||||
|
||||
# temperature
|
||||
if 'temperature' not in cp:
|
||||
cp["temperature"] = 1
|
||||
#
|
||||
# if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
|
||||
# raise ValueError("temperature must be a float between 0 and 2")
|
||||
#
|
||||
|
||||
# top_p
|
||||
if 'top_p' not in cp:
|
||||
cp["top_p"] = 1
|
||||
|
||||
# if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
|
||||
# raise ValueError("top_p must be a float between 0 and 2")
|
||||
#
|
||||
# presence_penalty
|
||||
if 'presence_penalty' not in cp:
|
||||
cp["presence_penalty"] = 0
|
||||
|
||||
# if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
|
||||
# raise ValueError("presence_penalty must be a float between -2 and 2")
|
||||
#
|
||||
# presence_penalty
|
||||
if 'frequency_penalty' not in cp:
|
||||
cp["frequency_penalty"] = 0
|
||||
|
||||
# if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
|
||||
# raise ValueError("frequency_penalty must be a float between -2 and 2")
|
||||
# stop
|
||||
if 'stop' not in cp:
|
||||
cp["stop"] = []
|
||||
elif not isinstance(cp["stop"], list):
|
||||
raise ValueError("stop in model.completion_params must be of list type")
|
||||
|
||||
if len(cp["stop"]) > 4:
|
||||
raise ValueError("stop sequences must be less than 4")
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_cp = {
|
||||
@ -75,7 +66,8 @@ class AppModelConfigService:
|
||||
"temperature": cp["temperature"],
|
||||
"top_p": cp["top_p"],
|
||||
"presence_penalty": cp["presence_penalty"],
|
||||
"frequency_penalty": cp["frequency_penalty"]
|
||||
"frequency_penalty": cp["frequency_penalty"],
|
||||
"stop": cp["stop"]
|
||||
}
|
||||
|
||||
return filtered_cp
|
||||
@ -211,6 +203,10 @@ class AppModelConfigService:
|
||||
model_ids = [m['id'] for m in model_list]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
# model.mode
|
||||
if 'mode' not in config['model'] or not config['model']["mode"]:
|
||||
config['model']["mode"] = ""
|
||||
|
||||
# model.completion_params
|
||||
if 'completion_params' not in config["model"]:
|
||||
@ -339,6 +335,9 @@ class AppModelConfigService:
|
||||
# dataset_query_variable
|
||||
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
|
||||
|
||||
# advanced prompt validation
|
||||
AppModelConfigService.is_advanced_prompt_valid(config, mode)
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_config = {
|
||||
"opening_statement": config["opening_statement"],
|
||||
@ -351,12 +350,17 @@ class AppModelConfigService:
|
||||
"model": {
|
||||
"provider": config["model"]["provider"],
|
||||
"name": config["model"]["name"],
|
||||
"mode": config['model']["mode"],
|
||||
"completion_params": config["model"]["completion_params"]
|
||||
},
|
||||
"user_input_form": config["user_input_form"],
|
||||
"dataset_query_variable": config.get('dataset_query_variable'),
|
||||
"pre_prompt": config["pre_prompt"],
|
||||
"agent_mode": config["agent_mode"]
|
||||
"agent_mode": config["agent_mode"],
|
||||
"prompt_type": config["prompt_type"],
|
||||
"chat_prompt_config": config["chat_prompt_config"],
|
||||
"completion_prompt_config": config["completion_prompt_config"],
|
||||
"dataset_configs": config["dataset_configs"]
|
||||
}
|
||||
|
||||
return filtered_config
|
||||
@ -375,4 +379,51 @@ class AppModelConfigService:
|
||||
|
||||
if dataset_exists and not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
|
||||
# prompt_type
|
||||
if 'prompt_type' not in config or not config["prompt_type"]:
|
||||
config["prompt_type"] = "simple"
|
||||
|
||||
if config['prompt_type'] not in ['simple', 'advanced']:
|
||||
raise ValueError("prompt_type must be in ['simple', 'advanced']")
|
||||
|
||||
# chat_prompt_config
|
||||
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
|
||||
config["chat_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["chat_prompt_config"], dict):
|
||||
raise ValueError("chat_prompt_config must be of object type")
|
||||
|
||||
# completion_prompt_config
|
||||
if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
|
||||
config["completion_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
# dataset_configs
|
||||
if 'dataset_configs' not in config or not config["dataset_configs"]:
|
||||
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if config['prompt_type'] == 'advanced':
|
||||
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
|
||||
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
|
||||
|
||||
if config['model']["mode"] not in ['chat', 'completion']:
|
||||
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
|
||||
|
||||
if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
|
||||
if not user_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
|
||||
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
|
||||
@ -244,7 +244,8 @@ class CompletionService:
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
|
||||
message_id: str, streaming: bool = True) -> Union[dict | Generator]:
|
||||
message_id: str, streaming: bool = True,
|
||||
retriever_from: str = 'dev') -> Union[dict | Generator]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
|
||||
@ -266,14 +267,11 @@ class CompletionService:
|
||||
raise MoreLikeThisDisabledError()
|
||||
|
||||
app_model_config = message.app_model_config
|
||||
|
||||
if message.override_model_configs:
|
||||
override_model_configs = json.loads(message.override_model_configs)
|
||||
pre_prompt = override_model_configs.get("pre_prompt", '')
|
||||
elif app_model_config:
|
||||
pre_prompt = app_model_config.pre_prompt
|
||||
else:
|
||||
raise AppModelConfigBrokenError()
|
||||
model_dict = app_model_config.model_dict
|
||||
completion_params = model_dict.get('completion_params')
|
||||
completion_params['temperature'] = 0.9
|
||||
model_dict['completion_params'] = completion_params
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
generate_task_id = str(uuid.uuid4())
|
||||
|
||||
@ -282,58 +280,28 @@ class CompletionService:
|
||||
|
||||
user = cls.get_real_user_instead_of_proxy_obj(user)
|
||||
|
||||
generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
|
||||
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'generate_task_id': generate_task_id,
|
||||
'detached_app_model': app_model,
|
||||
'app_model_config': app_model_config,
|
||||
'detached_message': message,
|
||||
'pre_prompt': pre_prompt,
|
||||
'query': message.query,
|
||||
'inputs': message.inputs,
|
||||
'detached_user': user,
|
||||
'streaming': streaming
|
||||
'detached_conversation': None,
|
||||
'streaming': streaming,
|
||||
'is_model_config_override': True,
|
||||
'retriever_from': retriever_from
|
||||
})
|
||||
|
||||
generate_worker_thread.start()
|
||||
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
|
||||
# wait for 10 minutes to close the thread
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
|
||||
generate_task_id)
|
||||
|
||||
return cls.compact_response(pubsub, streaming)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
|
||||
app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
|
||||
detached_user: Union[Account, EndUser], streaming: bool):
|
||||
with flask_app.app_context():
|
||||
# fixed the state of the model object when it detached from the original session
|
||||
user = db.session.merge(detached_user)
|
||||
app_model = db.session.merge(detached_app_model)
|
||||
message = db.session.merge(detached_message)
|
||||
|
||||
try:
|
||||
# run
|
||||
Completion.generate_more_like_this(
|
||||
task_id=generate_task_id,
|
||||
app=app_model,
|
||||
user=user,
|
||||
message=message,
|
||||
pre_prompt=pre_prompt,
|
||||
app_model_config=app_model_config,
|
||||
streaming=streaming
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
pass
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
|
||||
ModelCurrentlyNotSupportError) as e:
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
except LLMAuthorizationError:
|
||||
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
|
||||
except Exception as e:
|
||||
logging.exception("Unknown Error in completion")
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
finally:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
||||
if user_inputs is None:
|
||||
|
||||
@ -385,9 +385,6 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def delete_document(document):
|
||||
if document.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]:
|
||||
raise DocumentIndexingError()
|
||||
|
||||
# trigger document_was_deleted signal
|
||||
document_was_deleted.send(document.id, dataset_id=document.dataset_id)
|
||||
|
||||
|
||||
@ -482,6 +482,9 @@ class ProviderService:
|
||||
'features': []
|
||||
}
|
||||
|
||||
if 'mode' in model:
|
||||
valid_model_dict['model_mode'] = model['mode']
|
||||
|
||||
if 'features' in model:
|
||||
valid_model_dict['features'] = model['features']
|
||||
|
||||
|
||||
@ -31,22 +31,24 @@ def clean_document_task(document_id: str, dataset_id: str):
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
vector_index.delete_by_document_id(document_id)
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
vector_index.delete_by_document_id(document_id)
|
||||
|
||||
# delete from keyword index
|
||||
if index_node_ids:
|
||||
kw_index.delete_by_ids(index_node_ids)
|
||||
# delete from keyword index
|
||||
if index_node_ids:
|
||||
kw_index.delete_by_ids(index_node_ids)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
|
||||
db.session.commit()
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Cleaned document when document deleted: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
|
||||
db.session.commit()
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Cleaned document when document deleted: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Cleaned document when document deleted failed")
|
||||
|
||||
@ -30,13 +30,11 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
Document.dataset_id == dataset_id
|
||||
).first()
|
||||
|
||||
if not document:
|
||||
raise NotFound('Document not found')
|
||||
|
||||
document.indexing_status = 'parsing'
|
||||
document.processing_started_at = datetime.datetime.utcnow()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
if document:
|
||||
document.indexing_status = 'parsing'
|
||||
document.processing_started_at = datetime.datetime.utcnow()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
|
||||
@ -1,55 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.model_providers.error import LLMError, ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message
|
||||
|
||||
|
||||
@shared_task(queue='generation')
|
||||
def generate_conversation_summary_task(conversation_id: str):
|
||||
"""
|
||||
Async Generate conversation summary
|
||||
:param conversation_id:
|
||||
|
||||
Usage: generate_conversation_summary_task.delay(conversation_id)
|
||||
"""
|
||||
logging.info(click.style('Start generate conversation summary: {}'.format(conversation_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
if not conversation:
|
||||
raise NotFound('Conversation not found')
|
||||
|
||||
try:
|
||||
# get conversation messages count
|
||||
history_message_count = conversation.message_count
|
||||
if history_message_count >= 5 and not conversation.summary:
|
||||
app_model = conversation.app
|
||||
if not app_model:
|
||||
return
|
||||
|
||||
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
|
||||
.order_by(Message.created_at.asc()).all()
|
||||
|
||||
conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages)
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
except (LLMError, ProviderTokenNotInitError):
|
||||
conversation.summary = '[No Summary]'
|
||||
db.session.commit()
|
||||
pass
|
||||
except Exception as e:
|
||||
conversation.summary = '[No Summary]'
|
||||
db.session.commit()
|
||||
logging.exception(e)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at),
|
||||
fg='green'))
|
||||
@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
def test_get_num_tokens(mock_decrypt):
|
||||
model = get_mock_model('claude-2')
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 6
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
|
||||
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
|
||||
rst = openai_model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 22
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt):
|
||||
model = get_mock_model('baichuan2-53b')
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst > 0
|
||||
|
||||
@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker):
|
||||
|
||||
model = get_mock_model('baichuan2-53b')
|
||||
messages = [
|
||||
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
]
|
||||
rst = model.run(
|
||||
messages,
|
||||
@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
|
||||
|
||||
model = get_mock_model('baichuan2-53b', streaming=True)
|
||||
messages = [
|
||||
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
]
|
||||
rst = model.run(
|
||||
messages
|
||||
|
||||
@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock
|
||||
mocker
|
||||
)
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 5
|
||||
|
||||
@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
|
||||
mocker
|
||||
)
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 5
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
def test_get_num_tokens(mock_decrypt):
|
||||
model = get_mock_model('abab5.5-chat')
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 5
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt):
|
||||
openai_model = get_mock_openai_model('gpt-3.5-turbo')
|
||||
rst = openai_model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 22
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
def test_get_num_tokens(mock_decrypt, mocker):
|
||||
model = get_mock_model('facebook/opt-125m', mocker)
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 5
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
def test_get_num_tokens(mock_decrypt, mocker):
|
||||
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 7
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
def test_get_num_tokens(mock_decrypt):
|
||||
model = get_mock_model('spark')
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 6
|
||||
|
||||
|
||||
@ -44,9 +44,9 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_num_tokens(mock_decrypt):
|
||||
model = get_mock_model('qwen-v1')
|
||||
model = get_mock_model('qwen-turbo')
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 5
|
||||
|
||||
@ -55,7 +55,7 @@ def test_get_num_tokens(mock_decrypt):
|
||||
def test_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('qwen-v1')
|
||||
model = get_mock_model('qwen-turbo')
|
||||
rst = model.run(
|
||||
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
|
||||
stop=['\nHuman:'],
|
||||
|
||||
@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
def test_get_num_tokens(mock_decrypt):
|
||||
model = get_mock_model('ernie-bot')
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 5
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
def test_get_num_tokens(mock_decrypt, mocker):
|
||||
model = get_mock_model('llama-2-chat', mocker)
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst == 5
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt):
|
||||
model = get_mock_model('chatglm_lite')
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst > 0
|
||||
|
||||
@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker):
|
||||
|
||||
model = get_mock_model('chatglm_lite')
|
||||
messages = [
|
||||
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
]
|
||||
rst = model.run(
|
||||
messages,
|
||||
@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
|
||||
|
||||
model = get_mock_model('chatglm_lite', streaming=True)
|
||||
messages = [
|
||||
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
]
|
||||
rst = model.run(
|
||||
messages
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Type
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode
|
||||
from core.model_providers.models.llm.openai_model import OpenAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider):
|
||||
return 'fake'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return [{'id': 'test_model', 'name': 'Test Model'}]
|
||||
return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
return OpenAIModel
|
||||
|
||||
@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker):
|
||||
provider = FakeModelProvider(provider=Provider())
|
||||
result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
|
||||
|
||||
assert result == [{'id': 'test_model', 'name': 'test_model'}]
|
||||
assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}]
|
||||
|
||||
|
||||
def test_check_quota_over_limit(mocker):
|
||||
|
||||
@ -2,6 +2,8 @@ import pytest
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
from langchain.schema import ChatResult, ChatGeneration, AIMessage
|
||||
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.minimax_provider import MinimaxProvider
|
||||
from models.provider import ProviderType, Provider
|
||||
@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('langchain.llms.minimax.Minimax._call', return_value='abc')
|
||||
mocker.patch('core.third_party.langchain.llms.minimax_llm.MinimaxChatLLM._generate',
|
||||
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ version: '3.1'
|
||||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:0.3.26
|
||||
image: langgenius/dify-api:0.3.28
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'api' starts the API server.
|
||||
@ -124,7 +124,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:0.3.26
|
||||
image: langgenius/dify-api:0.3.28
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'worker' starts the Celery worker for processing the queue.
|
||||
@ -192,7 +192,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:0.3.26
|
||||
image: langgenius/dify-web:0.3.28
|
||||
restart: always
|
||||
environment:
|
||||
EDITION: SELF_HOSTED
|
||||
|
||||
@ -122,7 +122,7 @@ const NewAppDialog = ({ show, onSuccess, onClose }: NewAppDialogProps) => {
|
||||
<input ref={nameInputRef} className='h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' placeholder={t('app.appNamePlaceholder') || ''}/>
|
||||
</div>
|
||||
|
||||
<div className='h-[247px]'>
|
||||
<div className='h-[247px] overflow-y-auto'>
|
||||
<div className={style.newItemCaption}>
|
||||
<h3 className='inline'>{t('app.newApp.captionAppType')}</h3>
|
||||
{isWithTemplate && (
|
||||
|
||||
@ -2,127 +2,23 @@ import { CodeGroup } from '@/app/components/develop/code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '@/app/components/develop/md.tsx'
|
||||
|
||||
# Dataset API
|
||||
<br/>
|
||||
<br/>
|
||||
<Heading
|
||||
url='/datasets'
|
||||
method='POST'
|
||||
title='Create an empty dataset'
|
||||
name='#create_empty_dataset'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Request Body
|
||||
<Properties>
|
||||
<Property name='name' type='string' key='name'>
|
||||
Dataset name
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets"
|
||||
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name"}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request POST '${apiBaseUrl}/v1/datasets' \
|
||||
--header 'Authorization: Bearer {api_key}' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"name": "name"
|
||||
}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"id": "",
|
||||
"name": "name",
|
||||
"description": null,
|
||||
"provider": "vendor",
|
||||
"permission": "only_me",
|
||||
"data_source_type": null,
|
||||
"indexing_technique": null,
|
||||
"app_count": 0,
|
||||
"document_count": 0,
|
||||
"word_count": 0,
|
||||
"created_by": "",
|
||||
"created_at": 1695636173,
|
||||
"updated_by": "",
|
||||
"updated_at": 1695636173,
|
||||
"embedding_model": null,
|
||||
"embedding_model_provider": null,
|
||||
"embedding_available": null
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
---
|
||||
<div>
|
||||
### Authentication
|
||||
|
||||
Service API of Dify authenticates using an `API-Key`.
|
||||
|
||||
It is suggested that developers store the `API-Key` in the backend instead of sharing or storing it in the client side to avoid the leakage of the `API-Key`, which may lead to property loss.
|
||||
|
||||
All API requests should include your `API-Key` in the **`Authorization`** HTTP Header, as shown below:
|
||||
|
||||
<CodeGroup title="Code">
|
||||
```javascript
|
||||
Authorization: Bearer {API_KEY}
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
method='GET'
|
||||
title='Dataset list'
|
||||
name='#dataset_list'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Query
|
||||
<Properties>
|
||||
<Property name='page' type='string' key='page'>
|
||||
Page number
|
||||
</Property>
|
||||
<Property name='limit' type='string' key='limit'>
|
||||
Number of items returned, default 20, range 1-100
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets"
|
||||
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \
|
||||
--header 'Authorization: Bearer {api_key}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "",
|
||||
"name": "name",
|
||||
"description": "desc",
|
||||
"permission": "only_me",
|
||||
"data_source_type": "upload_file",
|
||||
"indexing_technique": "",
|
||||
"app_count": 2,
|
||||
"document_count": 10,
|
||||
"word_count": 1200,
|
||||
"created_by": "",
|
||||
"created_at": "",
|
||||
"updated_by": "",
|
||||
"updated_at": ""
|
||||
},
|
||||
...
|
||||
],
|
||||
"has_more": true,
|
||||
"limit": 20,
|
||||
"total": 50,
|
||||
"page": 1
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
</CodeGroup>
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
@ -329,6 +225,128 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
|
||||
---
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
method='POST'
|
||||
title='Create an empty dataset'
|
||||
name='#create_empty_dataset'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Request Body
|
||||
<Properties>
|
||||
<Property name='name' type='string' key='name'>
|
||||
Dataset name
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets"
|
||||
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name"}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request POST '${apiBaseUrl}/v1/datasets' \
|
||||
--header 'Authorization: Bearer {api_key}' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"name": "name"
|
||||
}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"id": "",
|
||||
"name": "name",
|
||||
"description": null,
|
||||
"provider": "vendor",
|
||||
"permission": "only_me",
|
||||
"data_source_type": null,
|
||||
"indexing_technique": null,
|
||||
"app_count": 0,
|
||||
"document_count": 0,
|
||||
"word_count": 0,
|
||||
"created_by": "",
|
||||
"created_at": 1695636173,
|
||||
"updated_by": "",
|
||||
"updated_at": 1695636173,
|
||||
"embedding_model": null,
|
||||
"embedding_model_provider": null,
|
||||
"embedding_available": null
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
---
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
method='GET'
|
||||
title='Dataset list'
|
||||
name='#dataset_list'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Query
|
||||
<Properties>
|
||||
<Property name='page' type='string' key='page'>
|
||||
Page number
|
||||
</Property>
|
||||
<Property name='limit' type='string' key='limit'>
|
||||
Number of items returned, default 20, range 1-100
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets"
|
||||
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \
|
||||
--header 'Authorization: Bearer {api_key}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "",
|
||||
"name": "name",
|
||||
"description": "desc",
|
||||
"permission": "only_me",
|
||||
"data_source_type": "upload_file",
|
||||
"indexing_technique": "",
|
||||
"app_count": 2,
|
||||
"document_count": 10,
|
||||
"word_count": 1200,
|
||||
"created_by": "",
|
||||
"created_at": "",
|
||||
"updated_by": "",
|
||||
"updated_at": ""
|
||||
},
|
||||
...
|
||||
],
|
||||
"has_more": true,
|
||||
"limit": 20,
|
||||
"total": 50,
|
||||
"page": 1
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
---
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update_by_text'
|
||||
method='POST'
|
||||
|
||||
@ -2,127 +2,23 @@ import { CodeGroup } from '@/app/components/develop/code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '@/app/components/develop/md.tsx'
|
||||
|
||||
# 数据集 API
|
||||
<br/>
|
||||
<br/>
|
||||
<Heading
|
||||
url='/datasets'
|
||||
method='POST'
|
||||
title='创建空数据集'
|
||||
name='#create_empty_dataset'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Request Body
|
||||
<Properties>
|
||||
<Property name='name' type='string' key='name'>
|
||||
数据集名称
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets"
|
||||
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name"}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request POST '${props.apiBaseUrl}/datasets' \
|
||||
--header 'Authorization: Bearer {api_key}' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"name": "name"
|
||||
}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"id": "",
|
||||
"name": "name",
|
||||
"description": null,
|
||||
"provider": "vendor",
|
||||
"permission": "only_me",
|
||||
"data_source_type": null,
|
||||
"indexing_technique": null,
|
||||
"app_count": 0,
|
||||
"document_count": 0,
|
||||
"word_count": 0,
|
||||
"created_by": "",
|
||||
"created_at": 1695636173,
|
||||
"updated_by": "",
|
||||
"updated_at": 1695636173,
|
||||
"embedding_model": null,
|
||||
"embedding_model_provider": null,
|
||||
"embedding_available": null
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
---
|
||||
<div>
|
||||
### 鉴权
|
||||
|
||||
Dify Service API 使用 `API-Key` 进行鉴权。
|
||||
|
||||
建议开发者把 `API-Key` 放在后端存储,而非分享或者放在客户端存储,以免 `API-Key` 泄露,导致财产损失。
|
||||
|
||||
所有 API 请求都应在 **`Authorization`** HTTP Header 中包含您的 `API-Key`,如下所示:
|
||||
|
||||
<CodeGroup title="Code">
|
||||
```javascript
|
||||
Authorization: Bearer {API_KEY}
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
method='GET'
|
||||
title='数据集列表'
|
||||
name='#dataset_list'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Query
|
||||
<Properties>
|
||||
<Property name='page' type='string' key='page'>
|
||||
页码
|
||||
</Property>
|
||||
<Property name='limit' type='string' key='limit'>
|
||||
返回条数,默认 20,范围 1-100
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets"
|
||||
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \
|
||||
--header 'Authorization: Bearer {api_key}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "",
|
||||
"name": "数据集名称",
|
||||
"description": "描述信息",
|
||||
"permission": "only_me",
|
||||
"data_source_type": "upload_file",
|
||||
"indexing_technique": "",
|
||||
"app_count": 2,
|
||||
"document_count": 10,
|
||||
"word_count": 1200,
|
||||
"created_by": "",
|
||||
"created_at": "",
|
||||
"updated_by": "",
|
||||
"updated_at": ""
|
||||
},
|
||||
...
|
||||
],
|
||||
"has_more": true,
|
||||
"limit": 20,
|
||||
"total": 50,
|
||||
"page": 1
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
</CodeGroup>
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
@ -329,6 +225,128 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
|
||||
---
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
method='POST'
|
||||
title='创建空数据集'
|
||||
name='#create_empty_dataset'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Request Body
|
||||
<Properties>
|
||||
<Property name='name' type='string' key='name'>
|
||||
数据集名称
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets"
|
||||
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "name"}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request POST '${props.apiBaseUrl}/datasets' \
|
||||
--header 'Authorization: Bearer {api_key}' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"name": "name"
|
||||
}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"id": "",
|
||||
"name": "name",
|
||||
"description": null,
|
||||
"provider": "vendor",
|
||||
"permission": "only_me",
|
||||
"data_source_type": null,
|
||||
"indexing_technique": null,
|
||||
"app_count": 0,
|
||||
"document_count": 0,
|
||||
"word_count": 0,
|
||||
"created_by": "",
|
||||
"created_at": 1695636173,
|
||||
"updated_by": "",
|
||||
"updated_at": 1695636173,
|
||||
"embedding_model": null,
|
||||
"embedding_model_provider": null,
|
||||
"embedding_available": null
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
---
|
||||
|
||||
<Heading
|
||||
url='/datasets'
|
||||
method='GET'
|
||||
title='数据集列表'
|
||||
name='#dataset_list'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
### Query
|
||||
<Properties>
|
||||
<Property name='page' type='string' key='page'>
|
||||
页码
|
||||
</Property>
|
||||
<Property name='limit' type='string' key='limit'>
|
||||
返回条数,默认 20,范围 1-100
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
<Col sticky>
|
||||
<CodeGroup
|
||||
title="Request"
|
||||
tag="POST"
|
||||
label="/datasets"
|
||||
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
|
||||
>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl --location --request GET '${props.apiBaseUrl}/datasets?page=1&limit=20' \
|
||||
--header 'Authorization: Bearer {api_key}'
|
||||
```
|
||||
</CodeGroup>
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"id": "",
|
||||
"name": "数据集名称",
|
||||
"description": "描述信息",
|
||||
"permission": "only_me",
|
||||
"data_source_type": "upload_file",
|
||||
"indexing_technique": "",
|
||||
"app_count": 2,
|
||||
"document_count": 10,
|
||||
"word_count": 1200,
|
||||
"created_by": "",
|
||||
"created_at": "",
|
||||
"updated_by": "",
|
||||
"updated_at": ""
|
||||
},
|
||||
...
|
||||
],
|
||||
"has_more": true,
|
||||
"limit": 20,
|
||||
"total": 50,
|
||||
"page": 1
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
---
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update_by_text'
|
||||
method='POST'
|
||||
|
||||
@ -53,6 +53,7 @@ export type IChatProps = {
|
||||
isShowConfigElem?: boolean
|
||||
dataSets?: DataSet[]
|
||||
isShowCitationHitInfo?: boolean
|
||||
isShowPromptLog?: boolean
|
||||
}
|
||||
|
||||
const Chat: FC<IChatProps> = ({
|
||||
@ -81,6 +82,7 @@ const Chat: FC<IChatProps> = ({
|
||||
isShowConfigElem,
|
||||
dataSets,
|
||||
isShowCitationHitInfo,
|
||||
isShowPromptLog,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
@ -186,7 +188,18 @@ const Chat: FC<IChatProps> = ({
|
||||
isShowCitationHitInfo={isShowCitationHitInfo}
|
||||
/>
|
||||
}
|
||||
return <Question key={item.id} id={item.id} content={item.content} more={item.more} useCurrentUserAvatar={useCurrentUserAvatar} />
|
||||
return (
|
||||
<Question
|
||||
key={item.id}
|
||||
id={item.id}
|
||||
content={item.content}
|
||||
more={item.more}
|
||||
useCurrentUserAvatar={useCurrentUserAvatar}
|
||||
item={item}
|
||||
isShowPromptLog={isShowPromptLog}
|
||||
isResponsing={isResponsing}
|
||||
/>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
{
|
||||
|
||||
70
web/app/components/app/chat/log/index.tsx
Normal file
70
web/app/components/app/chat/log/index.tsx
Normal file
@ -0,0 +1,70 @@
|
||||
import type { Dispatch, FC, ReactNode, RefObject, SetStateAction } from 'react'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { File02 } from '@/app/components/base/icons/src/vender/line/files'
|
||||
import PromptLogModal from '@/app/components/base/prompt-log-modal'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
|
||||
export type LogData = {
|
||||
role: string
|
||||
text: string
|
||||
}
|
||||
|
||||
type LogProps = {
|
||||
containerRef: RefObject<HTMLElement>
|
||||
log: LogData[]
|
||||
children?: (v: Dispatch<SetStateAction<boolean>>) => ReactNode
|
||||
}
|
||||
const Log: FC<LogProps> = ({
|
||||
containerRef,
|
||||
children,
|
||||
log,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [showModal, setShowModal] = useState(false)
|
||||
const [width, setWidth] = useState(0)
|
||||
|
||||
const adjustModalWidth = () => {
|
||||
if (containerRef.current)
|
||||
setWidth(document.body.clientWidth - (containerRef.current?.clientWidth + 56 + 16))
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
adjustModalWidth()
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<>
|
||||
{
|
||||
children
|
||||
? children(setShowModal)
|
||||
: (
|
||||
<Tooltip selector='prompt-log-modal-trigger' content={t('common.operation.log') || ''}>
|
||||
<div className={`
|
||||
hidden absolute -left-[14px] -top-[14px] group-hover:block w-7 h-7
|
||||
p-0.5 rounded-lg border-[0.5px] border-gray-100 bg-white shadow-md cursor-pointer
|
||||
`}>
|
||||
<div
|
||||
className='flex items-center justify-center rounded-md w-full h-full hover:bg-gray-100'
|
||||
onClick={() => setShowModal(true)}
|
||||
>
|
||||
<File02 className='w-4 h-4 text-gray-500' />
|
||||
</div>
|
||||
</div>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
{
|
||||
showModal && (
|
||||
<PromptLogModal
|
||||
width={width}
|
||||
log={log}
|
||||
onCancel={() => setShowModal(false)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default Log
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user