Compare commits

...

43 Commits
0.4.1 ... 0.4.3

Author SHA1 Message Date
97c972f14d feat: bump version 0.4.3 (#1930) 2024-01-04 21:16:47 +08:00
3fa5204b0c feat: optimize performance (#1928) 2024-01-04 20:48:54 +08:00
5a756ca981 fix: xinference cache (#1926) 2024-01-04 20:39:58 +08:00
01f9feff9f fix a typo in file agent_app_runner.py (#1927) 2024-01-04 20:39:06 +08:00
2757494265 alter schedule timedelta (#1923)
Co-authored-by: jyong <jyong@dify.ai>
2024-01-04 18:10:16 +08:00
b88a8f7bb1 feat: optimize invoke errors (#1922) 2024-01-04 17:49:55 +08:00
b4225bedb5 fix: app create raise error when no available model providers (#1921) 2024-01-04 17:33:26 +08:00
a82b4d315a Fix comparison bug in ApplicationQueueManager (#1919) 2024-01-04 17:33:08 +08:00
3d92784bd4 fix: email template style (#1914) 2024-01-04 16:53:11 +08:00
c06e766d7e feat: model parameter prefefined (#1917) 2024-01-04 16:46:51 +08:00
4a3d15b6de fix customer spliter character (#1915)
Co-authored-by: jyong <jyong@dify.ai>
2024-01-04 16:21:48 +08:00
a798dcfae9 web: Add style CI workflow to enforce eslint checks on web module (#1910) 2024-01-04 15:37:51 +08:00
b4a170cb8a ci: Properly cache pip packages (#1912) 2024-01-04 15:31:07 +08:00
665318da3d fix: remove useless code. (#1913) 2024-01-04 15:27:05 +08:00
66cdf577f5 fix: model quota format (#1909) 2024-01-04 14:51:26 +08:00
891218615e fix: window size changed causes result regeneration (#1908) 2024-01-04 14:07:38 +08:00
a938e1f184 fix: notion_indexing_estimate embedding_model_instance NPE (#1907) 2024-01-04 13:28:52 +08:00
7c7ee633c1 fix: spark credentials validate (#1906) 2024-01-04 13:20:45 +08:00
18af84e193 fix: array oob in azure openai embeddings (#1905) 2024-01-04 13:11:54 +08:00
025b859c7e fix: tongyi generate error (#1904) 2024-01-04 12:57:45 +08:00
0e239a4f71 fix: read file encoding error (#1902)
Co-authored-by: maple <1071520@gi>
2024-01-04 12:52:10 +08:00
ca85b0afbe fix: remove useless code (#1903) 2024-01-04 11:10:20 +08:00
a0a9461f79 Fix/add qdrant timeout default value (#1901)
Co-authored-by: jyong <jyong@dify.ai>
2024-01-04 10:58:47 +08:00
6a2eb5f442 fix: customize model schema fetch failed raise error (#1900) 2024-01-04 10:53:50 +08:00
0c5892bcb6 fix: zhipuai chatglm turbo prompts must user, assistant in sequence (#1899) 2024-01-04 10:39:21 +08:00
91ff07fcf7 bump version to 0.4.2 (#1898) 2024-01-04 01:35:07 +08:00
bb7af56e69 fix: zhipuai history format wrong (#1897) 2024-01-04 01:30:23 +08:00
77f9e8ce0f add example api url endpoint in placeholder (#1887)
Co-authored-by: takatost <takatost@gmail.com>
2024-01-04 01:16:51 +08:00
5ca4c4a44d add qdrant client timeout limit (#1894)
Co-authored-by: jyong <jyong@dify.ai>
2024-01-03 22:23:04 +08:00
a44022c388 Grammar fix (#1892) 2024-01-03 22:13:12 +08:00
6333cf43a8 fix: anthropic messages empty raise errors (#1893) 2024-01-03 22:12:14 +08:00
91ee62d1ab fix: huggingface and replicate. (#1888) 2024-01-03 18:29:44 +08:00
ede69b4659 fix: gemini block error (#1877)
Co-authored-by: chenhe <guchenhe@gmail.com>
2024-01-03 17:45:15 +08:00
61aaeff413 Fix variable name in AgentApplicationRunner (#1884) 2024-01-03 17:44:41 +08:00
4e1cd75f6f fix: model parameter stop sequence (#1885) 2024-01-03 17:15:29 +08:00
a8ff2e95da fix: model parameter modal initial value (#1883) 2024-01-03 17:10:37 +08:00
4d502ea44d fix: openai embedding list out of bound (#1879) 2024-01-03 15:30:22 +08:00
66b3588897 doc: Respect and prevent updating existed yarn lockfile when installing dependencies (#1871) 2024-01-03 15:27:19 +08:00
9134849744 fix: remove tiktoken from text splitter (#1876) 2024-01-03 13:02:56 +08:00
fcf8512956 fix: more like this. (#1875) 2024-01-03 12:51:19 +08:00
ae975b10e9 fix: openai origin credential not start with { (#1874) 2024-01-03 12:10:43 +08:00
b43f1441a9 Fix/model runtime (#1873) 2024-01-03 11:36:57 +08:00
5a2aa83030 fix: ciphertext error (#1872) 2024-01-03 11:20:46 +08:00
113 changed files with 936 additions and 1280 deletions

View File

@ -31,28 +31,19 @@ jobs:
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
MOCK_SWITCH: true
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Cache pip dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-
cache: 'pip'
cache-dependency-path: ./api/requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r api/requirements.txt
run: pip install -r ./api/requirements.txt
- name: Run pytest
run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py

34
.github/workflows/style.yml vendored Normal file
View File

@ -0,0 +1,34 @@
name: Style check
on:
pull_request:
branches:
- main
push:
branches:
- deploy/dev
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup NodeJS
uses: actions/setup-node@v4
with:
node-version: 18
cache: yarn
cache-dependency-path: ./web/package.json
- name: Web dependencies
run: |
cd ./web
yarn install --frozen-lockfile
- name: Web style check
run: |
cd ./web
yarn run lint

View File

@ -91,7 +91,7 @@ After running, you can access the Dify dashboard in your browser at [http://loca
### Helm Chart
A big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
Big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
You can go to https://github.com/BorisPolonsky/dify-helm for deployment information.
### Configuration

View File

@ -65,6 +65,7 @@ WEAVIATE_BATCH_SIZE=100
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456
QDRANT_CLIENT_TIMEOUT=20
# Milvus configuration
MILVUS_HOST=127.0.0.1

View File

@ -36,6 +36,7 @@ DEFAULTS = {
'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
'WEAVIATE_GRPC_ENABLED': 'True',
'WEAVIATE_BATCH_SIZE': 100,
'QDRANT_CLIENT_TIMEOUT': 20,
'CELERY_BACKEND': 'database',
'LOG_LEVEL': 'INFO',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
@ -87,7 +88,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.4.1"
self.CURRENT_VERSION = "0.4.3"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@ -197,6 +198,7 @@ class Config:
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
self.QDRANT_CLIENT_TIMEOUT = get_env('QDRANT_CLIENT_TIMEOUT')
# milvus / zilliz setting
self.MILVUS_HOST = get_env('MILVUS_HOST')

View File

@ -141,15 +141,9 @@ class AppListApi(Resource):
model_type=ModelType.LLM
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
model_instance = None
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
if model_instance:
model_dict = app_model_config.model_dict
model_dict['provider'] = model_instance.provider
model_dict['name'] = model_instance.model

View File

@ -58,7 +58,7 @@ class ChatMessageAudioApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

View File

@ -78,7 +78,7 @@ class CompletionMessageApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
@ -153,7 +153,7 @@ class ChatMessageApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

View File

@ -38,7 +38,7 @@ class RuleGenerateApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
return rules

View File

@ -228,7 +228,7 @@ class MessageMoreLikeThisApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
@ -256,7 +256,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(
api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
@ -296,7 +296,7 @@ class MessageSuggestedQuestionApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -54,7 +54,7 @@ class ChatAudioApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

View File

@ -70,7 +70,7 @@ class CompletionApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
@ -134,7 +134,7 @@ class ChatApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
@ -175,7 +175,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

View File

@ -104,7 +104,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
@ -131,7 +131,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
@ -169,7 +169,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -54,7 +54,7 @@ class UniversalChatAudioApi(UniversalChatResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

View File

@ -89,7 +89,7 @@ class UniversalChatApi(UniversalChatResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
@ -126,7 +126,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

View File

@ -133,7 +133,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -50,7 +50,7 @@ class AudioApi(AppApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

View File

@ -67,7 +67,7 @@ class CompletionApi(AppApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
@ -131,7 +131,7 @@ class ChatApi(AppApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
@ -171,7 +171,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

View File

@ -52,7 +52,7 @@ class AudioApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

View File

@ -64,7 +64,7 @@ class CompletionApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
@ -124,7 +124,7 @@ class ChatApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
@ -164,7 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

View File

@ -138,7 +138,7 @@ class MessageMoreLikeThisApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
@ -165,7 +165,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
@ -202,7 +202,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -75,7 +75,7 @@ class AgentApplicationRunner(AppRunner):
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
@ -153,7 +153,7 @@ class AgentApplicationRunner(AppRunner):
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
@ -237,8 +237,8 @@ class AgentApplicationRunner(AppRunner):
all_message_tokens = 0
all_answer_tokens = 0
for agent_thought in agent_thoughts:
all_message_tokens += agent_thought.message_tokens
all_answer_tokens += agent_thought.answer_tokens
all_message_tokens += agent_thought.message_token
all_answer_tokens += agent_thought.answer_token
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

View File

@ -50,7 +50,7 @@ class AppRunner:
max_tokens = 0
# get prompt messages without memory and context
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=model_config,
prompt_template_entity=prompt_template_entity,
@ -107,7 +107,7 @@ class AppRunner:
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
model_config.parameters[parameter_rule.name] = max_tokens
def originze_prompt_messages(self, app_record: App,
def organize_prompt_messages(self, app_record: App,
model_config: ModelConfigEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],

View File

@ -79,7 +79,7 @@ class BasicApplicationRunner(AppRunner):
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
@ -164,7 +164,7 @@ class BasicApplicationRunner(AppRunner):
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,

View File

@ -473,7 +473,7 @@ class ApplicationManager:
more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
if more_like_this_dict:
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
properties['more_like_this'] = copy_app_model_config_dict.get('opening_statement')
properties['more_like_this'] = True
# speech to text
speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')

View File

@ -173,7 +173,7 @@ class ApplicationQueueManager:
return
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
if result != f"{user_prefix}-{user_id}":
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)

View File

@ -65,7 +65,8 @@ class FileExtractor:
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url)
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
else MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension == '.docx':
@ -84,7 +85,8 @@ class FileExtractor:
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url)
loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
else TextLoader(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)

View File

@ -1,5 +1,6 @@
import datetime
import json
import logging
import time
from json import JSONDecodeError
from typing import Optional, List, Dict, Tuple, Iterator
@ -9,6 +10,7 @@ from pydantic import BaseModel
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.model_providers import model_provider_factory
@ -18,6 +20,8 @@ from core.model_runtime.utils import encoders
from extensions.ext_database import db
from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
logger = logging.getLogger(__name__)
class ProviderConfiguration(BaseModel):
"""
@ -168,6 +172,14 @@ class ProviderConfiguration(BaseModel):
db.session.add(provider_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(ProviderType.CUSTOM)
def delete_custom_credentials(self) -> None:
@ -190,6 +202,14 @@ class ProviderConfiguration(BaseModel):
db.session.delete(provider_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]:
"""
@ -311,6 +331,14 @@ class ProviderConfiguration(BaseModel):
db.session.add(provider_model_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)
provider_model_credentials_cache.delete()
def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
"""
Delete custom model credentials.
@ -332,6 +360,14 @@ class ProviderConfiguration(BaseModel):
db.session.delete(provider_model_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)
provider_model_credentials_cache.delete()
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
@ -544,13 +580,17 @@ class ProviderConfiguration(BaseModel):
if model_configuration.model_type not in model_types:
continue
custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
try:
custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
)
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue
if not custom_model_schema:
continue

View File

@ -61,7 +61,7 @@ class Extensible:
builtin_file_path = os.path.join(subdir_path, '__builtin__')
if os.path.exists(builtin_file_path):
with open(builtin_file_path, 'r') as f:
with open(builtin_file_path, 'r', encoding='utf-8') as f:
position = int(f.read().strip())
if (extension_name + '.py') not in file_names:
@ -93,7 +93,7 @@ class Extensible:
json_path = os.path.join(subdir_path, 'schema.json')
json_data = {}
if os.path.exists(json_path):
with open(json_path, 'r') as f:
with open(json_path, 'r', encoding='utf-8') as f:
json_data = json.load(f)
extensions[extension_name] = ModuleExtension(

View File

@ -1,35 +0,0 @@
{
"label": {
"en-US": "Weather Search",
"zh-Hans": "天气查询"
},
"form_schema": [
{
"type": "select",
"label": {
"en-US": "Temperature Unit",
"zh-Hans": "温度单位"
},
"variable": "temperature_unit",
"required": true,
"options": [
{
"label": {
"en-US": "Fahrenheit",
"zh-Hans": "华氏度"
},
"value": "fahrenheit"
},
{
"label": {
"en-US": "Centigrade",
"zh-Hans": "摄氏度"
},
"value": "centigrade"
}
],
"default": "centigrade",
"placeholder": "Please select temperature unit"
}
]
}

View File

@ -1,45 +0,0 @@
from typing import Optional
from core.external_data_tool.base import ExternalDataTool
class WeatherSearch(ExternalDataTool):
"""
The name of custom type must be unique, keep the same with directory and file name.
"""
name: str = "weather_search"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
schema.json validation. It will be called when user save the config.
Example:
.. code-block:: python
config = {
"temperature_unit": "centigrade"
}
:param tenant_id: the id of workspace
:param config: the variables of form config
:return:
"""
if not config.get('temperature_unit'):
raise ValueError('temperature unit is required')
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
city = inputs.get('city')
temperature_unit = self.config.get('temperature_unit')
if temperature_unit == 'fahrenheit':
return f'Weather in {city} is 32°F'
else:
return f'Weather in {city} is 0°C'

View File

@ -0,0 +1,51 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum):
PROVIDER = "provider"
MODEL = "provider_model"
class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return cached_provider_credentials
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 3600, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)

View File

@ -18,6 +18,7 @@ from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
def to_qdrant_params(self):
@ -33,6 +34,7 @@ class QdrantConfig(BaseModel):
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
}

View File

@ -49,7 +49,8 @@ class VectorIndex:
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
),
embeddings=embeddings
)

View File

@ -5,12 +5,12 @@ import re
import threading
import time
import uuid
from typing import Optional, List, cast
from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any
from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.data_loader.file_extractor import FileExtractor
@ -23,7 +23,8 @@ from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -58,7 +59,7 @@ class IndexingRunner:
first()
# load file
text_docs = self._load_data(dataset_document)
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
# get splitter
splitter = self._get_splitter(processing_rule)
@ -112,15 +113,14 @@ class IndexingRunner:
for document_segment in document_segments:
db.session.delete(document_segment)
db.session.commit()
# load file
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# load file
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
# get splitter
splitter = self._get_splitter(processing_rule)
@ -237,14 +237,15 @@ class IndexingRunner:
preview_texts = []
total_segments = 0
for file_detail in file_details:
# load data from file
text_docs = FileExtractor.load(file_detail)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# load data from file
text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
# get splitter
splitter = self._get_splitter(processing_rule)
@ -381,13 +382,15 @@ class IndexingRunner:
)
total_segments += len(documents)
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model_instance:
if indexing_technique == 'high_quality' and embedding_model_type_instance:
tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
@ -456,7 +459,7 @@ class IndexingRunner:
one_or_none()
if file_detail:
text_docs = FileExtractor.load(file_detail, is_automatic=True)
text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
@ -502,7 +505,8 @@ class IndexingRunner:
if separator:
separator = separator.replace('\\n', '\n')
character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=0,
fixed_separator=separator,
@ -510,7 +514,7 @@ class IndexingRunner:
)
else:
# Automatic segmentation
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=0,
separators=["\n\n", "", ".", " ", ""]

View File

@ -147,13 +147,13 @@ class AIModel(ABC):
# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, 'r') as f:
with open(position_file_path, 'r', encoding='utf-8') as f:
position_map = yaml.safe_load(f)
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
with open(model_schema_yaml_path, 'r') as f:
with open(model_schema_yaml_path, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
new_parameter_rules = []
@ -236,16 +236,6 @@ class AIModel(ABC):
:param credentials: model credentials
:return: model schema
"""
if 'schema' in credentials:
schema_dict = json.loads(credentials['schema'])
try:
model_instance = AIModelEntity.parse_obj(schema_dict)
return model_instance
except ValidationError as e:
logging.exception(f"Invalid model schema for {model}")
return self._get_customizable_model_schema(model, credentials)
return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:

View File

@ -132,8 +132,8 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None
real_model = model
for chunk in result:
try:
try:
for chunk in result:
yield chunk
self._trigger_new_chunk_callbacks(
@ -156,8 +156,8 @@ class LargeLanguageModel(AIModel):
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception as e:
raise self._transform_invoke_error(e)
except Exception as e:
raise self._transform_invoke_error(e)
self._trigger_after_invoke_callbacks(
model=model,

View File

@ -47,7 +47,7 @@ class ModelProvider(ABC):
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, 'r') as f:
with open(yaml_path, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
try:

View File

@ -252,6 +252,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ''
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
messages.append(AssistantPromptMessage(content=""))

View File

@ -54,7 +54,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
_iter = range(0, len(tokens), max_chunks)
for i in _iter:
embeddings, embedding_used_tokens = self._embedding_invoke(
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=tokens[i: i + max_chunks],
@ -62,7 +62,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
)
used_tokens += embedding_used_tokens
batched_embeddings += [data for data in embeddings]
batched_embeddings += embeddings_batch
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
@ -73,7 +73,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
for i in range(len(texts)):
_result = results[i]
if len(_result) == 0:
embeddings, embedding_used_tokens = self._embedding_invoke(
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=[""],
@ -81,7 +81,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
)
used_tokens += embedding_used_tokens
average = embeddings[0]
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()

View File

@ -3,6 +3,7 @@ from typing import Optional, Generator, Union, List
import google.generativeai as genai
import google.api_core.exceptions as exceptions
import google.generativeai.client as client
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from google.generativeai.types import GenerateContentResponse, ContentType
from google.generativeai.types.content_types import to_part
@ -124,7 +125,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
@ -139,13 +140,21 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client
safety_settings={
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(
**config_kwargs
),
stream=stream
stream=stream,
safety_settings=safety_settings
)
if stream:
@ -169,7 +178,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content=response.text
)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
@ -202,11 +210,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for chunk in response:
content = chunk.text
index += 1
assistant_prompt_message = AssistantPromptMessage(
content=content if content else '',
)
if not response._done:
# transform assistant message to prompt message

View File

@ -154,20 +154,31 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
content=chunk.token.text
)
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if chunk.details:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
usage=usage,
),
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
usage=usage,
finish_reason=chunk.details.finish_reason,
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
),
)
def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
if isinstance(response, str):

View File

@ -1,7 +1,7 @@
from typing import Generator, List, Optional, Union, cast
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
@ -156,9 +156,9 @@ class LocalAILarguageModel(LargeLanguageModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
completion_model = None
if credentials['completion_type'] == 'chat_completion':
completion_model = LLMMode.CHAT
completion_model = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion':
completion_model = LLMMode.COMPLETION
completion_model = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
@ -202,7 +202,7 @@ class LocalAILarguageModel(LargeLanguageModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={ 'mode': completion_model } if completion_model else {},
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
parameter_rules=rules
)

View File

@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel):
class ModelProviderFactory:
model_provider_extensions: dict[str, ModelProviderExtension] = None
def __init__(self) -> None:
# for cache in memory
self.get_providers()
def get_providers(self) -> list[ProviderEntity]:
"""
Get all providers
@ -212,7 +216,7 @@ class ModelProviderFactory:
# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, 'r') as f:
with open(position_file_path, 'r', encoding='utf-8') as f:
position_map = yaml.safe_load(f)
# traverse all model_provider_dir_paths

View File

@ -68,7 +68,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
for i in _iter:
# call embedding model
embeddings, embedding_used_tokens = self._embedding_invoke(
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=tokens[i: i + max_chunks],
@ -76,7 +76,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
)
used_tokens += embedding_used_tokens
batched_embeddings += [data for data in embeddings]
batched_embeddings += embeddings_batch
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
@ -87,7 +87,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
for i in range(len(texts)):
_result = results[i]
if len(_result) == 0:
embeddings, embedding_used_tokens = self._embedding_invoke(
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=[""],
@ -95,7 +95,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
)
used_tokens += embedding_used_tokens
average = embeddings[0]
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()

View File

@ -117,9 +117,9 @@ class _CommonOAI_API_Compat:
if model_type == ModelType.LLM:
if credentials['mode'] == 'chat':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif credentials['mode'] == 'completion':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")

View File

@ -1,19 +1,21 @@
import logging
from decimal import Decimal
from urllib.parse import urljoin
import requests
import json
from typing import Optional, Generator, Union, List, cast
from sympy import comp
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.utils import helper
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
AssistantPromptMessage, PromptMessageContent, \
PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
ToolPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
DefaultParameterName, \
ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError
@ -72,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
:return:
"""
return self._num_tokens_from_messages(model, prompt_messages, tools)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
@ -91,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials['endpoint_url']
if not endpoint_url.endswith('/'):
endpoint_url += '/'
# prepare the payload for a simple ping to the model
data = {
@ -107,11 +111,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
"content": "ping"
},
]
endpoint_url = urljoin(endpoint_url, 'chat/completions')
elif completion_type is LLMMode.COMPLETION:
data['prompt'] = 'ping'
endpoint_url = urljoin(endpoint_url, 'completions')
else:
raise ValueError("Unsupported completion type for model configuration.")
# send a post request to validate the credentials
response = requests.post(
endpoint_url,
@ -121,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
)
if response.status_code != 200:
raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
if (completion_type is LLMMode.CHAT
and ('object' not in json_result or json_result['object'] != 'chat.completion')):
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
elif (completion_type is LLMMode.COMPLETION
and ('object' not in json_result or json_result['object'] != 'text_completion')):
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response object, must be \'text_completion\'')
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
@ -136,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
ModelPropertyKey.MODE: 'chat'
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MODE: credentials.get('mode'),
},
parameter_rules=[
ParameterRule(
@ -199,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return entity
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, \
user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm completion model
@ -225,7 +247,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials["endpoint_url"]
if not endpoint_url.endswith('/'):
endpoint_url += '/'
data = {
"model": model,
"stream": stream,
@ -235,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
completion_type = LLMMode.value_of(credentials['mode'])
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, 'chat/completions')
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
elif completion_type == LLMMode.COMPLETION:
endpoint_url = urljoin(endpoint_url, 'completions')
data['prompt'] = prompt_messages[0].content
else:
raise ValueError("Unsupported completion type for model configuration.")
@ -247,8 +273,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
data["tool_choice"] = "auto"
for tool in tools:
formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
data["tools"] = formatted_tools
if stop:
@ -256,7 +282,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if user:
data["user"] = user
response = requests.post(
endpoint_url,
headers=headers,
@ -277,8 +303,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
@ -315,51 +341,64 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if chunk:
decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
if len(chunk_json['choices']) == 0:
if not chunk_json or len(chunk_json['choices']) == 0:
continue
delta = chunk_json['choices'][0]['delta']
chunk_index = chunk_json['choices'][0]['index']
choice = chunk_json['choices'][0]
chunk_index = choice['index'] if 'index' in choice else chunk_index
if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
if 'delta' in choice:
delta = choice['delta']
if delta.get('content') is None or delta.get('content') == '':
continue
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
# function_call = self._extract_response_function_call(assistant_message_function_call)
# tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.get('content', ''),
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta.get('content', '')
elif 'text' in choice:
if choice.get('text') is None or choice.get('text') == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=choice.get('text', '')
)
full_assistant_content += choice.get('text', '')
else:
continue
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
# function_call = self._extract_response_function_call(assistant_message_function_call)
# tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.get('content', ''),
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta.get('content', '')
# check payload indicator for completion
if chunk_json['choices'][0].get('finish_reason') is not None:
yield create_final_llm_result_chunk(
index=chunk_index,
message=assistant_prompt_message,
finish_reason=chunk_json['choices'][0]['finish_reason']
)
else:
yield LLMResultChunk(
model=model,
@ -375,10 +414,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=AssistantPromptMessage(content=""),
finish_reason="End of stream."
)
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> LLMResult:
chunk_index += 1
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> LLMResult:
response_json = response.json()
completion_type = LLMMode.value_of(credentials['mode'])
@ -457,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
in
message.tool_calls]
# function_call = message.tool_calls[0]
# message_dict["function_call"] = {
@ -486,7 +528,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_string(self, model: str, text: str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
@ -509,10 +551,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
"""
Approximate num tokens with GPT2 tokenizer.
"""
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
@ -601,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
num_tokens += self._get_num_tokens_by_gpt2(required_field)
return num_tokens
def _extract_response_tool_calls(self,
response_tool_calls: list[dict]) \
-> list[AssistantPromptMessage.ToolCall]:

View File

@ -33,8 +33,8 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 API endpoint URL
en_US: Enter your API endpoint URL
zh_Hans: Base URL, eg. https://api.openai.com/v1
en_US: Base URL, eg. https://api.openai.com/v1
- variable: mode
show_on:
- variable: __model_type

View File

@ -1,6 +1,7 @@
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin
import requests
import json
@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = credentials['endpoint_url']
endpoint_url = urljoin(endpoint_url, 'embeddings')
extra_model_kwargs = {}
if user:
@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = credentials['endpoint_url']
endpoint_url = urljoin(endpoint_url, 'embeddings')
payload = {
'input': 'ping',
@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
)
if response.status_code != 200:
raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}")
raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
if 'model' not in json_result:
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response')
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[],

View File

@ -6,7 +6,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate import Open
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
@ -198,7 +198,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={
'mode': LLMMode.COMPLETION,
ModelPropertyKey.MODE: LLMMode.COMPLETION.value,
},
parameter_rules=rules
)

View File

@ -8,7 +8,7 @@ from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \
PromptMessageRole, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.replicate._common import _CommonReplicate
@ -91,7 +91,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={
'mode': model_type.value
ModelPropertyKey.MODE: model_type.value
},
parameter_rules=self._get_customizable_model_parameter_rules(model, credentials)
)
@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
)
for key, value in input_properties:
if key not in ['system_prompt', 'prompt']:
if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
value_type = value.get('type')
if not value_type:
@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
index = -1
current_completion: str = ""
stop_condition_reached = False
prediction_output_length = 10000
is_prediction_output_finished = False
for output in prediction.output_iterator():
current_completion += output
if not is_prediction_output_finished and prediction.status == 'succeeded':
prediction_output_length = len(prediction.output) - 1
is_prediction_output_finished = True
if stop:
for s in stop:
if s in current_completion:
@ -172,20 +180,30 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
content=output if output else ''
)
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if index < prediction_output_length:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
)
else:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
usage=usage,
),
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
usage=usage
)
)
def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
prompt_messages: list[PromptMessage]) -> LLMResult:

View File

@ -19,13 +19,23 @@ class SparkProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use `claude-instant-1` model for validate,
model_instance.validate_credentials(
model='spark-1.5',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='spark-3',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex

View File

@ -52,9 +52,13 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
:param tools: tools for tool calling
:return:
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
response = dashscope.Tokenization.call(
model=model,
prompt=self._convert_messages_to_prompt(prompt_messages),
**credentials_kwargs
)
if response.status_code == HTTPStatus.OK:
@ -108,10 +112,6 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
dashscope.api_key = credentials_kwargs['api_key']
print(credentials_kwargs, 'credentials_kwargs')
client = EnhanceTongyi(
model_name=model,
streaming=stream,
@ -121,7 +121,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
params = {
'model': model,
'prompt': self._convert_messages_to_prompt(prompt_messages),
**model_parameters
**model_parameters,
**credentials_kwargs
}
if stream:
responses = stream_generate_with_retry(
@ -222,7 +223,6 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
:param credentials:
:return:
"""
print(credentials, 'credentials')
credentials_kwargs = {
"api_key": credentials['dashscope_api_key'],
}

View File

@ -18,7 +18,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
@ -56,10 +56,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
}
"""
try:
XinferenceHelper.get_xinference_extra_parameter(
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
if 'completion_type' not in credentials:
if 'chat' in extra_param.model_ability:
credentials['completion_type'] = 'chat'
elif 'generate' in extra_param.model_ability:
credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
except KeyError as e:
@ -256,17 +264,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
]
completion_type = None
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
if 'chat' in extra_args.model_ability:
completion_type = LLMMode.CHAT
elif 'generate' in extra_args.model_ability:
completion_type = LLMMode.COMPLETION
if 'completion_type' in credentials:
if credentials['completion_type'] == 'chat':
completion_type = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion':
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
else:
raise NotImplementedError(f'xinference model ability {extra_args.model_ability} is not supported')
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
)
if 'chat' in extra_args.model_ability:
completion_type = LLMMode.CHAT.value
elif 'generate' in extra_args.model_ability:
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
entity = AIModelEntity(
model=model,
@ -276,7 +293,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={
'mode': completion_type,
ModelPropertyKey.MODE: completion_type,
},
parameter_rules=rules
)

View File

@ -33,10 +33,13 @@ class XinferenceHelper:
@staticmethod
def _clean_cache() -> None:
with cache_lock:
for model_uid, model in cache.items():
if model['expires'] < time():
try:
with cache_lock:
expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
for model_uid in expired_keys:
del cache[model_uid]
except RuntimeError as e:
pass
@staticmethod
def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:

View File

@ -8,8 +8,9 @@ from typing import (
Union
)
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \
SystemPromptMessage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, \
AssistantPromptMessage, \
SystemPromptMessage, PromptMessageRole
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
LLMResultChunkDelta
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -111,13 +112,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if len(prompt_messages) == 0:
raise ValueError('At least one message is required')
if prompt_messages[0].role.value == 'system':
if prompt_messages[0].role == PromptMessageRole.SYSTEM:
if not prompt_messages[0].content:
prompt_messages = prompt_messages[1:]
# resolve zhipuai model not support system message and user message, assistant message must be in sequence
new_prompt_messages = []
for prompt_message in prompt_messages:
copy_prompt_message = prompt_message.copy()
if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
if not isinstance(copy_prompt_message.content, str):
# not support image message
continue
if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER:
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
else:
if copy_prompt_message.role == PromptMessageRole.USER:
new_prompt_messages.append(copy_prompt_message)
else:
new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
new_prompt_messages.append(new_prompt_message)
else:
if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.ASSISTANT:
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
else:
new_prompt_messages.append(copy_prompt_message)
params = {
'model': model,
'prompt': [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages],
'prompt': [{
'role': prompt_message.role.value,
'content': prompt_message.content
} for prompt_message in new_prompt_messages],
**model_parameters
}

View File

@ -24,7 +24,7 @@ provider_credential_schema:
- variable: api_key
label:
en_US: APIKey
type: text-input
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 APIKey

View File

@ -1,93 +0,0 @@
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
class CloudServiceModeration(Moderation):
"""
The name of custom type must be unique, keep the same with directory and file name.
"""
name: str = "cloud_service"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
schema.json validation. It will be called when user save the config.
Example:
.. code-block:: python
config = {
"cloud_provider": "GoogleCloud",
"api_endpoint": "https://api.example.com",
"api_keys": "123456",
"inputs_config": {
"enabled": True,
"preset_response": "Your content violates our usage policy. Please revise and try again."
},
"outputs_config": {
"enabled": True,
"preset_response": "Your content violates our usage policy. Please revise and try again."
}
}
:param tenant_id: the id of workspace
:param config: the variables of form config
:return:
"""
cls._validate_inputs_and_outputs_config(config, True)
if not config.get("cloud_provider"):
raise ValueError("cloud_provider is required")
if not config.get("api_endpoint"):
raise ValueError("api_endpoint is required")
if not config.get("api_keys"):
raise ValueError("api_keys is required")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
:param inputs: user inputs
:param query: the query of chat app, there is empty if is completion app
:return: the moderation result
"""
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
preset_response = self.config['inputs_config']['preset_response']
if query:
inputs['query__'] = query
flagged = self._is_violated(inputs)
# return ModerationInputsResult(flagged=flagged, action=ModerationAction.OVERRIDED, inputs=inputs, query=query)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
"""
Moderation for outputs.
:param text: the text of LLM response
:return: the moderation result
"""
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
preset_response = self.config['outputs_config']['preset_response']
flagged = self._is_violated({'text': text})
# return ModerationOutputsResult(flagged=flagged, action=ModerationAction.OVERRIDED, text=text)
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def _is_violated(self, inputs: dict):
"""
The main logic of moderation.
:param inputs:
:return: the moderation result
"""
return False

View File

@ -1,65 +0,0 @@
{
"label": {
"en-US": "Cloud Service",
"zh-Hans": "云服务"
},
"form_schema": [
{
"type": "select",
"label": {
"en-US": "Cloud Provider",
"zh-Hans": "云厂商"
},
"variable": "cloud_provider",
"required": true,
"options": [
{
"label": {
"en-US": "AWS",
"zh-Hans": "亚马逊"
},
"value": "AWS"
},
{
"label": {
"en-US": "Google Cloud",
"zh-Hans": "谷歌云"
},
"value": "GoogleCloud"
},
{
"label": {
"en-US": "Azure Cloud",
"zh-Hans": "微软云"
},
"value": "Azure"
}
],
"default": "GoogleCloud",
"placeholder": ""
},
{
"type": "text-input",
"label": {
"en-US": "API Endpoint",
"zh-Hans": "API Endpoint"
},
"variable": "api_endpoint",
"required": true,
"max_length": 100,
"default": "",
"placeholder": "https://api.example.com"
},
{
"type": "paragraph",
"label": {
"en-US": "API Key",
"zh-Hans": "API Key"
},
"variable": "api_keys",
"required": true,
"default": "",
"placeholder": "Paste your API key here"
}
]
}

View File

@ -207,7 +207,7 @@ class PromptTransform:
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
with open(json_file_path, 'r', encoding='utf-8') as json_file:
return json.load(json_file)
def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict,

View File

@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \
SystemConfiguration, QuotaConfiguration
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.model_providers import model_provider_factory
@ -79,9 +80,6 @@ class ProviderManager:
# Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
provider_configurations = ProviderConfigurations(
tenant_id=tenant_id
)
@ -100,19 +98,17 @@ class ProviderManager:
# Convert to custom configuration
custom_configuration = self._to_custom_configuration(
tenant_id,
provider_entity,
provider_records,
provider_model_records,
decoding_rsa_key,
decoding_cipher_rsa
provider_model_records
)
# Convert to system configuration
system_configuration = self._to_system_configuration(
tenant_id,
provider_entity,
provider_records,
decoding_rsa_key,
decoding_cipher_rsa
provider_records
)
# Get preferred provider type
@ -401,28 +397,29 @@ class ProviderManager:
Provider.tenant_id == tenant_id,
Provider.provider_name == provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.is_valid == True
Provider.quota_type == ProviderQuotaType.TRIAL.value
).first()
if provider_record and not provider_record.is_valid:
provider_record.is_valid = True
db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(provider_record)
return provider_name_to_provider_records_dict
def _to_custom_configuration(self,
tenant_id: str,
provider_entity: ProviderEntity,
provider_records: list[Provider],
provider_model_records: list[ProviderModel],
decoding_rsa_key,
decoding_cipher_rsa) -> CustomConfiguration:
provider_model_records: list[ProviderModel]) -> CustomConfiguration:
"""
Convert to custom configuration.
:param tenant_id: workspace id
:param provider_entity: provider entity
:param provider_records: provider records
:param provider_model_records: provider model records
:param decoding_rsa_key: decoding rsa key
:param decoding_cipher_rsa: decoding cipher rsa
:return:
"""
# Get provider credential secret variables
@ -445,18 +442,48 @@ class ProviderManager:
# Get custom provider credentials
custom_provider_configuration = None
if custom_provider_record:
try:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
provider_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=custom_provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
# Get cached provider credentials
cached_provider_credentials = provider_credentials_cache.get()
if not cached_provider_credentials:
try:
# fix origin data
if (custom_provider_record.encrypted_config
and not custom_provider_record.encrypted_config.startswith("{")):
provider_credentials = {
"openai_api_key": custom_provider_record.encrypted_config
}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
except ValueError:
pass
# cache provider credentials
provider_credentials_cache.set(
credentials=provider_credentials
)
else:
provider_credentials = cached_provider_credentials
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials
@ -474,18 +501,41 @@ class ProviderManager:
if not provider_model_record.encrypted_config:
continue
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
continue
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
# Get cached provider model credentials
cached_provider_model_credentials = provider_model_credentials_cache.get()
if not cached_provider_model_credentials:
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
continue
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
except ValueError:
pass
# cache provider model credentials
provider_model_credentials_cache.set(
credentials=provider_model_credentials
)
else:
provider_model_credentials = cached_provider_model_credentials
custom_model_configurations.append(
CustomModelConfiguration(
@ -501,17 +551,15 @@ class ProviderManager:
)
def _to_system_configuration(self,
tenant_id: str,
provider_entity: ProviderEntity,
provider_records: list[Provider],
decoding_rsa_key,
decoding_cipher_rsa) -> SystemConfiguration:
provider_records: list[Provider]) -> SystemConfiguration:
"""
Convert to system configuration.
:param tenant_id: workspace id
:param provider_entity: provider entity
:param provider_records: provider records
:param decoding_rsa_key: decoding rsa key
:param decoding_cipher_rsa: decoding cipher rsa
:return:
"""
# Get hosting configuration
@ -564,26 +612,49 @@ class ProviderManager:
provider_record = quota_type_to_provider_records_dict.get(current_quota_type)
if provider_record:
try:
provider_credentials = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema else []
provider_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
# Get cached provider credentials
cached_provider_credentials = provider_credentials_cache.get()
current_using_credentials = provider_credentials
if not cached_provider_credentials:
try:
provider_credentials = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema else []
)
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
except ValueError:
pass
current_using_credentials = provider_credentials
# cache provider credentials
provider_credentials_cache.set(
credentials=current_using_credentials
)
else:
current_using_credentials = cached_provider_credentials
else:
current_using_credentials = {}

View File

@ -7,10 +7,38 @@ from typing import (
Optional,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
"""
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
"""
@classmethod
def from_gpt2_encoder(
cls: Type[TS],
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
):
def _token_encoder(text: str) -> int:
return GPT2Tokenizer.get_num_tokens(text)
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"encoding_name": encoding_name,
"model_name": model_name,
"allowed_special": allowed_special,
"disallowed_special": disallowed_special,
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_token_encoder, **kwargs)
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
@ -65,4 +93,4 @@ class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
final_chunks.extend(merged_text)
return final_chunks
return final_chunks

View File

@ -46,11 +46,11 @@ def init_app(app: Flask) -> Celery:
beat_schedule = {
'clean_embedding_cache_task': {
'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task',
'schedule': timedelta(minutes=1),
'schedule': timedelta(days=7),
},
'clean_unused_datasets_task': {
'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
'schedule': timedelta(minutes=10),
'schedule': timedelta(days=7),
}
}
celery_app.conf.update(

View File

@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple
import requests
from flask import current_app
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, DefaultModelEntity
from core.entities.model_entities import ModelStatus
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

View File

@ -60,7 +60,7 @@
<p>Dear {{ to }},</p>
<p>{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.</p>
<p>You can now log in to Dify using the GitHub or Google account associated with this email.</p>
<p style="text-align: center;"><a class="button" href="{{ url }}">Login Here</a></p>
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p>
</div>
<div class="footer">
<p>Best regards,</p>

View File

@ -60,7 +60,7 @@
<p>尊敬的 {{ to }}</p>
<p>{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
<p>您现在可以使用与此邮件相对应的 GitHub 或 Google 账号登录 Dify。</p>
<p style="text-align: center;"><a class="button" href="{{ url }}">在此登录</a></p>
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
</div>
<div class="footer">
<p>此致,</p>

View File

@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.4.1
image: langgenius/dify-api:0.4.3
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -92,6 +92,8 @@ services:
QDRANT_URL: http://qdrant:6333
# The Qdrant API key.
QDRANT_API_KEY: difyai123456
# The Qdrant clinet timeout setting.
QDRANT_CLIENT_TIMEOUT: 20
# Milvus configuration Only available when VECTOR_STORE is `milvus`.
# The milvus host.
MILVUS_HOST: 127.0.0.1
@ -128,7 +130,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.4.1
image: langgenius/dify-api:0.4.3
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@ -170,6 +172,8 @@ services:
QDRANT_URL: http://qdrant:6333
# The Qdrant API key.
QDRANT_API_KEY: difyai123456
# The Qdrant clinet timeout setting.
QDRANT_CLIENT_TIMEOUT: 20
# Milvus configuration Only available when VECTOR_STORE is `milvus`.
# The milvus host.
MILVUS_HOST: 127.0.0.1
@ -196,7 +200,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.4.1
image: langgenius/dify-web:0.4.3
restart: always
environment:
EDITION: SELF_HOSTED

View File

@ -23,6 +23,7 @@
]
}
],
"react-hooks/exhaustive-deps": "warn"
"react-hooks/exhaustive-deps": "warn",
"react/display-name": "warn"
}
}
}

View File

@ -10,7 +10,7 @@ First, install the dependencies:
```bash
npm install
# or
yarn
yarn install --frozen-lockfile
```
Then, configure the environment variables. Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. Modify the values of these environment variables according to your requirements:

View File

@ -1,6 +1,6 @@
'use client'
import { useTranslation } from "react-i18next"
import { useTranslation } from 'react-i18next'
const DatasetFooter = () => {
const { t } = useTranslation()

View File

@ -10,4 +10,4 @@ const TextGeneration: FC<IMainProps> = () => {
)
}
export default React.memo(TextGeneration)
export default React.memo(TextGeneration)

View File

@ -1,13 +1,14 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
import s from './style.module.css'
export interface ILoaidingAnimProps {
export type ILoaidingAnimProps = {
type: 'text' | 'avatar'
}
const LoaidingAnim: FC<ILoaidingAnimProps> = ({
type
type,
}) => {
return (
<div className={`${s['dot-flashing']} ${s[type]}`}></div>

View File

@ -23,7 +23,6 @@ const style = {
overflow: 'auto',
}
// eslint-disable-next-line react/display-name
const Flowchart = React.forwardRef((props: {
PrimitiveCode: string
}, ref) => {

View File

@ -1,12 +1,13 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
export interface IGroupNameProps {
export type IGroupNameProps = {
name: string
}
const GroupName: FC<IGroupNameProps> = ({
name
name,
}) => {
return (
<div className='flex items-center mb-1'>

View File

@ -1,7 +1,8 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
const MoreLikeThisIcon: FC = ({ }) => {
const MoreLikeThisIcon: FC = () => {
return (
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fillRule="evenodd" clipRule="evenodd" d="M5.83914 0.666748H10.1609C10.6975 0.666741 11.1404 0.666734 11.5012 0.696212C11.8759 0.726829 12.2204 0.792538 12.544 0.957399C13.0457 1.21306 13.4537 1.62101 13.7093 2.12277C13.8742 2.44633 13.9399 2.7908 13.9705 3.16553C14 3.52633 14 3.96923 14 4.50587V7.41171C14 7.62908 14 7.73776 13.9652 7.80784C13.9303 7.87806 13.8939 7.91566 13.8249 7.95288C13.756 7.99003 13.6262 7.99438 13.3665 8.00307C12.8879 8.01909 12.4204 8.14633 11.997 8.36429C10.9478 7.82388 9.62021 7.82912 8.53296 8.73228C7.15064 9.88056 6.92784 11.8645 8.0466 13.2641C8.36602 13.6637 8.91519 14.1949 9.40533 14.6492C9.49781 14.7349 9.54405 14.7777 9.5632 14.8041C9.70784 15.003 9.5994 15.2795 9.35808 15.3271C9.32614 15.3334 9.26453 15.3334 9.14129 15.3334H5.83912C5.30248 15.3334 4.85958 15.3334 4.49878 15.304C4.12405 15.2733 3.77958 15.2076 3.45603 15.0428C2.95426 14.7871 2.54631 14.3792 2.29065 13.8774C2.12579 13.5538 2.06008 13.2094 2.02946 12.8346C1.99999 12.4738 1.99999 12.0309 2 11.4943V4.50587C1.99999 3.96924 1.99999 3.52632 2.02946 3.16553C2.06008 2.7908 2.12579 2.44633 2.29065 2.12277C2.54631 1.62101 2.95426 1.21306 3.45603 0.957399C3.77958 0.792538 4.12405 0.726829 4.49878 0.696212C4.85957 0.666734 5.3025 0.666741 5.83914 0.666748ZM4.66667 5.33342C4.29848 5.33342 4 5.63189 4 6.00008C4 6.36827 4.29848 6.66675 4.66667 6.66675H8.66667C9.03486 6.66675 9.33333 6.36827 9.33333 6.00008C9.33333 5.63189 9.03486 5.33342 8.66667 5.33342H4.66667ZM4 8.66675C4 8.29856 4.29848 8.00008 4.66667 8.00008H6C6.36819 8.00008 6.66667 8.29856 6.66667 8.66675C6.66667 9.03494 6.36819 9.33342 6 9.33342H4.66667C4.29848 9.33342 4 9.03494 4 8.66675ZM4.66667 2.66675C4.29848 2.66675 4 2.96523 4 3.33342C4 3.7016 4.29848 4.00008 4.66667 4.00008H10.6667C11.0349 4.00008 11.3333 3.7016 11.3333 3.33342C11.3333 2.96523 11.0349 2.66675 10.6667 2.66675H4.66667Z" fill="#DD2590" />

View File

@ -1,9 +1,10 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import { PlusIcon } from '@heroicons/react/20/solid'
export interface IOperationBtnProps {
export type IOperationBtnProps = {
type: 'add' | 'edit'
actionName?: string
onClick: () => void
@ -14,13 +15,13 @@ const iconMap = {
edit: (<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M6.99998 11.6666H12.25M1.75 11.6666H2.72682C3.01217 11.6666 3.15485 11.6666 3.28912 11.6344C3.40816 11.6058 3.52196 11.5587 3.62635 11.4947C3.74408 11.4226 3.84497 11.3217 4.04675 11.1199L11.375 3.79164C11.8583 3.30839 11.8583 2.52488 11.375 2.04164C10.8918 1.55839 10.1083 1.55839 9.62501 2.04164L2.29674 9.3699C2.09496 9.57168 1.99407 9.67257 1.92192 9.7903C1.85795 9.89469 1.81081 10.0085 1.78224 10.1275C1.75 10.2618 1.75 10.4045 1.75 10.6898V11.6666Z" stroke="#344054" strokeWidth="1.25" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
),
}
const OperationBtn: FC<IOperationBtnProps> = ({
type,
actionName,
onClick
onClick,
}) => {
const { t } = useTranslation()
return (

View File

@ -1,9 +1,10 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
import s from './style.module.css'
export interface IVarHighlightProps {
export type IVarHighlightProps = {
name: string
}
@ -31,6 +32,4 @@ export const varHighlightHTML = ({ name }: IVarHighlightProps) => {
return html
}
export default React.memo(VarHighlight)

View File

@ -1,10 +1,11 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import WarningMask from '.'
import Button from '@/app/components/base/button'
export interface IHasNotSetAPIProps {
export type IHasNotSetAPIProps = {
isTrailFinished: boolean
onSetting: () => void
}
@ -18,7 +19,7 @@ const icon = (
const HasNotSetAPI: FC<IHasNotSetAPIProps> = ({
isTrailFinished,
onSetting
onSetting,
}) => {
const { t } = useTranslation()

View File

@ -1,9 +1,10 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
import s from './style.module.css'
export interface IWarningMaskProps {
export type IWarningMaskProps = {
title: string
description: string
footer: React.ReactNode

View File

@ -1,423 +0,0 @@
'use client'
import type { FC } from 'react'
import React, { useEffect, useState } from 'react'
import cn from 'classnames'
import { useTranslation } from 'react-i18next'
import { useBoolean, useClickAway, useGetState } from 'ahooks'
import { InformationCircleIcon } from '@heroicons/react/24/outline'
import produce from 'immer'
import ParamItem from './param-item'
import { SlidersH } from '@/app/components/base/icons/src/vender/line/mediaAndDevices'
import Radio from '@/app/components/base/radio'
import Panel from '@/app/components/base/panel'
import type { CompletionParams } from '@/models/debug'
import { TONE_LIST } from '@/config'
import Toast from '@/app/components/base/toast'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import { formatNumber } from '@/utils/format'
import { Brush01 } from '@/app/components/base/icons/src/vender/solid/editor'
import { Scales02 } from '@/app/components/base/icons/src/vender/solid/FinanceAndECommerce'
import { Target04 } from '@/app/components/base/icons/src/vender/solid/general'
import { Sliders02 } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices'
import { fetchModelParams } from '@/service/debug'
import Loading from '@/app/components/base/loading'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import type { ModelModeType } from '@/types/app'
import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon'
import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
export type IConfigModelProps = {
isAdvancedMode: boolean
mode: string
modelId: string
provider: string
setModel: (model: { id: string; provider: string; mode: ModelModeType; features: string[] }) => void
completionParams: CompletionParams
onCompletionParamsChange: (newParams: CompletionParams) => void
disabled: boolean
}
const ConfigModel: FC<IConfigModelProps> = ({
isAdvancedMode,
modelId,
provider,
setModel,
completionParams,
onCompletionParamsChange,
disabled,
}) => {
const { t } = useTranslation()
const [isShowConfig, { setFalse: hideConfig, toggle: toogleShowConfig }] = useBoolean(false)
const [maxTokenSettingTipVisible, setMaxTokenSettingTipVisible] = useState(false)
const configContentRef = React.useRef(null)
const {
currentProvider,
currentModel: currModel,
textGenerationModelList,
} = useTextGenerationCurrentProviderAndModelAndModelList(
{ provider, model: modelId },
)
const media = useBreakpoints()
const isMobile = media === MediaType.mobile
// Cache loaded model param
const [allParams, setAllParams, getAllParams] = useGetState<Record<string, Record<string, any>>>({})
const currParams = allParams[provider]?.[modelId]
const hasEnableParams = currParams && Object.keys(currParams).some(key => currParams[key].enabled)
const allSupportParams = ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty', 'max_tokens']
const currSupportParams = currParams ? allSupportParams.filter(key => currParams[key].enabled) : allSupportParams
if (isAdvancedMode)
currSupportParams.push('stop')
useEffect(() => {
(async () => {
if (!allParams[provider]?.[modelId]) {
const res = await fetchModelParams(provider, modelId)
const newAllParams = produce(allParams, (draft) => {
if (!draft[provider])
draft[provider] = {}
draft[provider][modelId] = res
})
setAllParams(newAllParams)
}
})()
}, [provider, modelId, allParams, setAllParams])
useClickAway(() => {
hideConfig()
}, configContentRef)
const selectedModel = { name: modelId } // options.find(option => option.id === modelId)
const ensureModelParamLoaded = (provider: string, modelId: string) => {
return new Promise<void>((resolve) => {
if (getAllParams()[provider]?.[modelId]) {
resolve()
return
}
const runId = setInterval(() => {
if (getAllParams()[provider]?.[modelId]) {
resolve()
clearInterval(runId)
}
}, 500)
})
}
const transformValue = (value: number, fromRange: [number, number], toRange: [number, number]): number => {
const [fromStart = 0, fromEnd] = fromRange
const [toStart = 0, toEnd] = toRange
// The following three if is to avoid precision loss
if (fromStart === toStart && fromEnd === toEnd)
return value
if (value <= fromStart)
return toStart
if (value >= fromEnd)
return toEnd
const fromLength = fromEnd - fromStart
const toLength = toEnd - toStart
let adjustedValue = (value - fromStart) * (toLength / fromLength) + toStart
adjustedValue = parseFloat(adjustedValue.toFixed(2))
return adjustedValue
}
const handleSelectModel = ({ id, provider: nextProvider, mode, features }: { id: string; provider: string; mode: ModelModeType; features: string[] }) => {
return async () => {
const prevParamsRule = getAllParams()[provider]?.[modelId]
setModel({
id,
provider: nextProvider || 'openai',
mode,
features,
})
await ensureModelParamLoaded(nextProvider, id)
const nextParamsRule = getAllParams()[nextProvider]?.[id]
// debugger
const nextSelectModelMaxToken = nextParamsRule.max_tokens.max
const newConCompletionParams = produce(completionParams, (draft: any) => {
if (nextParamsRule.max_tokens.enabled) {
if (completionParams.max_tokens > nextSelectModelMaxToken) {
Toast.notify({
type: 'warning',
message: t('common.model.params.setToCurrentModelMaxTokenTip', { maxToken: formatNumber(nextSelectModelMaxToken) }),
})
draft.max_tokens = parseFloat((nextSelectModelMaxToken * 0.8).toFixed(2))
}
// prev don't have max token
if (!completionParams.max_tokens)
draft.max_tokens = nextParamsRule.max_tokens.default
}
else {
delete draft.max_tokens
}
allSupportParams.forEach((key) => {
if (key === 'max_tokens')
return
if (!nextParamsRule[key].enabled) {
delete draft[key]
return
}
if (draft[key] === undefined) {
draft[key] = nextParamsRule[key].default || 0
return
}
if (!prevParamsRule[key].enabled) {
draft[key] = nextParamsRule[key].default || 0
return
}
draft[key] = transformValue(
draft[key],
[prevParamsRule[key].min, prevParamsRule[key].max],
[nextParamsRule[key].min, nextParamsRule[key].max],
)
})
})
onCompletionParamsChange(newConCompletionParams)
}
}
// only openai support this
function matchToneId(completionParams: CompletionParams): number {
const remvoedCustomeTone = TONE_LIST.slice(0, -1)
const CUSTOM_TONE_ID = 4
const tone = remvoedCustomeTone.find((tone) => {
return tone.config?.temperature === completionParams.temperature
&& tone.config?.top_p === completionParams.top_p
&& tone.config?.presence_penalty === completionParams.presence_penalty
&& tone.config?.frequency_penalty === completionParams.frequency_penalty
})
return tone ? tone.id : CUSTOM_TONE_ID
}
// tone is a preset of completionParams.
const [toneId, setToneId] = React.useState(matchToneId(completionParams)) // default is Balanced
const toneTabBgClassName = ({
1: 'bg-[#F5F8FF]',
2: 'bg-[#F4F3FF]',
3: 'bg-[#F6FEFC]',
})[toneId] || ''
// set completionParams by toneId
const handleToneChange = (id: number) => {
if (id === 4)
return // custom tone
const tone = TONE_LIST.find(tone => tone.id === id)
if (tone) {
setToneId(id)
onCompletionParamsChange({
...tone.config,
max_tokens: completionParams.max_tokens,
} as CompletionParams)
}
}
useEffect(() => {
setToneId(matchToneId(completionParams))
}, [completionParams])
const handleParamChange = (key: string, value: number | string[]) => {
if (value === undefined)
return
if ((completionParams as any)[key] === value)
return
if (key === 'stop') {
onCompletionParamsChange({
...completionParams,
[key]: value as string[],
})
}
else {
const currParamsRule = getAllParams()[provider]?.[modelId]
let notOutRangeValue = parseFloat((value as number).toFixed(2))
notOutRangeValue = Math.max(currParamsRule[key].min, notOutRangeValue)
notOutRangeValue = Math.min(currParamsRule[key].max, notOutRangeValue)
onCompletionParamsChange({
...completionParams,
[key]: notOutRangeValue,
})
}
}
const ableStyle = 'bg-indigo-25 border-[#2A87F5] cursor-pointer'
const diabledStyle = 'bg-[#FFFCF5] border-[#F79009]'
const getToneIcon = (toneId: number) => {
const className = 'w-[14px] h-[14px]'
const res = ({
1: <Brush01 className={className} />,
2: <Scales02 className={className} />,
3: <Target04 className={className} />,
4: <Sliders02 className={className} />,
})[toneId]
return res
}
useEffect(() => {
if (!currParams)
return
const max = currParams.max_tokens.max
const isSupportMaxToken = currParams.max_tokens.enabled
if (isSupportMaxToken && currentProvider?.provider !== 'anthropic' && completionParams.max_tokens > max * 2 / 3)
setMaxTokenSettingTipVisible(true)
else
setMaxTokenSettingTipVisible(false)
}, [currParams, completionParams.max_tokens, setMaxTokenSettingTipVisible, currentProvider])
return (
<div className='relative' ref={configContentRef}>
<div
className={cn('flex items-center border h-8 px-2 space-x-2 rounded-lg', disabled ? diabledStyle : ableStyle)}
onClick={() => !disabled && toogleShowConfig()}
>
{
currentProvider && (
<ModelIcon
className='!w-5 !h-5'
provider={currentProvider}
/>
)
}
{
currModel && (
<ModelName
className='text-gray-900'
modelItem={currModel}
showMode={isAdvancedMode}
/>
)
}
{disabled ? <InformationCircleIcon className='w-4 h-4 text-[#F79009]' /> : <SlidersH className='w-4 h-4 text-indigo-600' />}
</div>
{isShowConfig && (
<Panel
className='absolute z-20 top-8 left-0 sm:left-[unset] sm:right-0 !w-fit sm:!w-[496px] bg-white !overflow-visible shadow-md'
keepUnFold
headerIcon={
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8.26865 0.790031C8.09143 0.753584 7.90866 0.753584 7.73144 0.790031C7.52659 0.832162 7.3435 0.934713 7.19794 1.01624L7.15826 1.03841L6.17628 1.58395C5.85443 1.76276 5.73846 2.16863 5.91727 2.49049C6.09608 2.81234 6.50195 2.9283 6.82381 2.74949L7.80579 2.20395C7.90681 2.14782 7.95839 2.11946 7.99686 2.10091L8.00004 2.09938L8.00323 2.10091C8.0417 2.11946 8.09327 2.14782 8.1943 2.20395L9.17628 2.74949C9.49814 2.9283 9.90401 2.81234 10.0828 2.49048C10.2616 2.16863 10.1457 1.76276 9.82381 1.58395L8.84183 1.03841L8.80215 1.01624C8.65659 0.934713 8.4735 0.832162 8.26865 0.790031Z" fill="#1C64F2" />
<path d="M12.8238 3.25062C12.5019 3.07181 12.0961 3.18777 11.9173 3.50963C11.7385 3.83148 11.8544 4.23735 12.1763 4.41616L12.6272 4.66668L12.1763 4.91719C11.8545 5.096 11.7385 5.50186 11.9173 5.82372C12.0961 6.14558 12.502 6.26154 12.8238 6.08273L13.3334 5.79966V6.33339C13.3334 6.70158 13.6319 7.00006 14 7.00006C14.3682 7.00006 14.6667 6.70158 14.6667 6.33339V5.29435L14.6668 5.24627C14.6673 5.12441 14.6678 4.98084 14.6452 4.83482C14.6869 4.67472 14.6696 4.49892 14.5829 4.34286C14.4904 4.1764 14.3371 4.06501 14.1662 4.02099C14.0496 3.93038 13.9239 3.86116 13.8173 3.8024L13.7752 3.77915L12.8238 3.25062Z" fill="#1C64F2" />
<path d="M3.8238 4.41616C4.14566 4.23735 4.26162 3.83148 4.08281 3.50963C3.90401 3.18777 3.49814 3.07181 3.17628 3.25062L2.22493 3.77915L2.18284 3.8024C2.07615 3.86116 1.95045 3.9304 1.83382 4.02102C1.66295 4.06506 1.50977 4.17643 1.41731 4.34286C1.33065 4.49886 1.31323 4.67459 1.35493 4.83464C1.33229 4.98072 1.33281 5.12436 1.33326 5.24627L1.33338 5.29435V6.33339C1.33338 6.70158 1.63185 7.00006 2.00004 7.00006C2.36823 7.00006 2.66671 6.70158 2.66671 6.33339V5.79961L3.17632 6.08273C3.49817 6.26154 3.90404 6.14558 4.08285 5.82372C4.26166 5.50186 4.1457 5.096 3.82384 4.91719L3.3729 4.66666L3.8238 4.41616Z" fill="#1C64F2" />
<path d="M2.66671 9.66672C2.66671 9.29853 2.36823 9.00006 2.00004 9.00006C1.63185 9.00006 1.33338 9.29853 1.33338 9.66672V10.7058L1.33326 10.7538C1.33262 10.9298 1.33181 11.1509 1.40069 11.3594C1.46024 11.5397 1.55759 11.7051 1.68622 11.8447C1.835 12.0061 2.02873 12.1128 2.18281 12.1977L2.22493 12.221L3.17628 12.7495C3.49814 12.9283 3.90401 12.8123 4.08281 12.4905C4.26162 12.1686 4.14566 11.7628 3.8238 11.584L2.87245 11.0554C2.76582 10.9962 2.71137 10.9656 2.67318 10.9413L2.66995 10.9392L2.66971 10.9354C2.66699 10.8902 2.66671 10.8277 2.66671 10.7058V9.66672Z" fill="#1C64F2" />
<path d="M14.6667 9.66672C14.6667 9.29853 14.3682 9.00006 14 9.00006C13.6319 9.00006 13.3334 9.29853 13.3334 9.66672V10.7058C13.3334 10.8277 13.3331 10.8902 13.3304 10.9354L13.3301 10.9392L13.3269 10.9413C13.2887 10.9656 13.2343 10.9962 13.1276 11.0554L12.1763 11.584C11.8544 11.7628 11.7385 12.1686 11.9173 12.4905C12.0961 12.8123 12.5019 12.9283 12.8238 12.7495L13.7752 12.221L13.8172 12.1977C13.9713 12.1128 14.1651 12.0061 14.3139 11.8447C14.4425 11.7051 14.5398 11.5397 14.5994 11.3594C14.6683 11.1509 14.6675 10.9298 14.6668 10.7538L14.6667 10.7058V9.66672Z" fill="#1C64F2" />
<path d="M6.82381 13.2506C6.50195 13.0718 6.09608 13.1878 5.91727 13.5096C5.73846 13.8315 5.85443 14.2374 6.17628 14.4162L7.15826 14.9617L7.19793 14.9839C7.29819 15.04 7.41625 15.1061 7.54696 15.1556C7.66589 15.2659 7.82512 15.3333 8.00008 15.3333C8.17507 15.3333 8.33431 15.2659 8.45324 15.1556C8.58391 15.1061 8.70193 15.04 8.80215 14.9839L8.84183 14.9617L9.82381 14.4162C10.1457 14.2374 10.2616 13.8315 10.0828 13.5096C9.90401 13.1878 9.49814 13.0718 9.17628 13.2506L8.66675 13.5337V13C8.66675 12.6318 8.36827 12.3333 8.00008 12.3333C7.63189 12.3333 7.33341 12.6318 7.33341 13V13.5337L6.82381 13.2506Z" fill="#1C64F2" />
<path d="M6.82384 6.58385C6.50199 6.40505 6.09612 6.52101 5.91731 6.84286C5.7385 7.16472 5.85446 7.57059 6.17632 7.7494L7.33341 8.39223V9.66663C7.33341 10.0348 7.63189 10.3333 8.00008 10.3333C8.36827 10.3333 8.66675 10.0348 8.66675 9.66663V8.39223L9.82384 7.7494C10.1457 7.57059 10.2617 7.16472 10.0829 6.84286C9.90404 6.52101 9.49817 6.40505 9.17632 6.58385L8.00008 7.23732L6.82384 6.58385Z" fill="#1C64F2" />
</svg>
}
title={t('appDebug.modelConfig.title')}
>
<div className='py-3 pl-10 pr-6 text-sm'>
<div className="flex items-center justify-between my-5 h-9">
<div>{t('appDebug.modelConfig.model')}</div>
<ModelSelector
defaultModel={{ model: modelId, provider }}
modelList={textGenerationModelList}
onSelect={({ provider, model }) => {
const targetProvider = textGenerationModelList.find(modelItem => modelItem.provider === provider)
const targetModelItem = targetProvider?.models.find(modelItem => modelItem.model === model)
handleSelectModel({
id: model,
provider,
mode: targetModelItem?.model_properties.mode as ModelModeType,
features: targetModelItem?.features || [],
})()
}}
/>
</div>
{hasEnableParams && (
<div className="border-b border-gray-100"></div>
)}
{/* Tone type */}
{['openai', 'azure_openai'].includes(provider) && (
<div className="mt-5 mb-4">
<div className="mb-3 text-sm text-gray-900">{t('appDebug.modelConfig.setTone')}</div>
<Radio.Group className={cn('!rounded-lg', toneTabBgClassName)} value={toneId} onChange={handleToneChange}>
<>
{TONE_LIST.slice(0, 3).map(tone => (
<div className='grow flex items-center' key={tone.id}>
<Radio
value={tone.id}
className={cn(tone.id === toneId && 'rounded-md border border-gray-200 shadow-md', '!mr-0 grow !px-1 sm:!px-2 !justify-center text-[13px] font-medium')}
labelClassName={cn(tone.id === toneId
? ({
1: 'text-[#6938EF]',
2: 'text-[#444CE7]',
3: 'text-[#107569]',
})[toneId]
: 'text-[#667085]', 'flex items-center space-x-2')}
>
<>
{getToneIcon(tone.id)}
{!isMobile && <div>{t(`common.model.tone.${tone.name}`) as string}</div>}
<div className=""></div>
</>
</Radio>
{tone.id !== toneId && tone.id + 1 !== toneId && (<div className='h-5 border-r border-gray-200'></div>)}
</div>
))}
</>
<Radio
value={TONE_LIST[3].id}
className={cn(toneId === 4 && 'rounded-md border border-gray-200 shadow-md', '!mr-0 grow !px-1 sm:!px-2 !justify-center text-[13px] font-medium')}
labelClassName={cn('flex items-center space-x-2 ', toneId === 4 ? 'text-[#155EEF]' : 'text-[#667085]')}
>
<>
{getToneIcon(TONE_LIST[3].id)}
{!isMobile && <div>{t(`common.model.tone.${TONE_LIST[3].name}`) as string}</div>}
</>
</Radio>
</Radio.Group>
</div>
)}
{/* Params */}
<div className={cn(hasEnableParams && 'mt-4', 'space-y-4', !allParams[provider]?.[modelId] && 'flex items-center min-h-[200px]')}>
{(allParams[provider]?.[modelId])
? (
currSupportParams.map(key => (<ParamItem
key={key}
id={key}
name={t(`common.model.params.${key === 'stop' ? 'stop_sequences' : key}`)}
tip={t(`common.model.params.${key === 'stop' ? 'stop_sequences' : key}Tip`)}
{...currParams[key] as any}
value={(completionParams as any)[key] as any}
onChange={handleParamChange}
inputType={key === 'stop' ? 'inputTag' : 'slider'}
/>))
)
: (
<Loading type='area' />
)}
</div>
</div>
{
maxTokenSettingTipVisible && (
<div className='flex py-2 pr-4 pl-5 rounded-bl-xl rounded-br-xl bg-[#FFFAEB] border-t border-[#FEF0C7]'>
<AlertTriangle className='shrink-0 mr-2 mt-[3px] w-3 h-3 text-[#F79009]' />
<div className='mr-2 text-xs font-medium text-gray-700'>{t('common.model.params.maxTokenSettingTip')}</div>
</div>
)
}
</Panel>
)}
</div>
)
}
export default React.memo(ConfigModel)

View File

@ -1,29 +0,0 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import type { ModelModeType } from '@/types/app'
type Props = {
className?: string
type: ModelModeType
isHighlight?: boolean
}
const ModelModeTypeLabel: FC<Props> = ({
className,
type,
isHighlight,
}) => {
const { t } = useTranslation()
return (
<div
className={cn(className, isHighlight ? 'border-indigo-300 text-indigo-600' : 'border-gray-300 text-gray-500', 'flex items-center h-4 px-1 border rounded text-xs font-semibold uppercase text-ellipsis overflow-hidden whitespace-nowrap')}
>
{t(`appDebug.modelConfig.modeType.${type}`)}
</div>
)
}
export default React.memo(ModelModeTypeLabel)

View File

@ -1,26 +0,0 @@
'use client'
import type { FC } from 'react'
import React from 'react'
export type IModelNameProps = {
modelId: string
modelDisplayName?: string
}
export const supportI18nModelName = [
'gpt-3.5-turbo', 'gpt-3.5-turbo-16k',
'gpt-4', 'gpt-4-32k',
'text-davinci-003', 'text-embedding-ada-002', 'whisper-1',
'claude-instant-1', 'claude-2',
]
const ModelName: FC<IModelNameProps> = ({
modelDisplayName,
}) => {
return (
<span className='text-ellipsis overflow-hidden whitespace-nowrap' title={modelDisplayName}>
{modelDisplayName}
</span>
)
}
export default React.memo(ModelName)

View File

@ -1,95 +0,0 @@
'use client'
import type { FC } from 'react'
import React, { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import Tooltip from '@/app/components/base/tooltip'
import Slider from '@/app/components/base/slider'
import TagInput from '@/app/components/base/tag-input'
export const getFitPrecisionValue = (num: number, precision: number | null) => {
if (!precision || !(`${num}`).includes('.'))
return num
const currNumPrecision = (`${num}`).split('.')[1].length
if (currNumPrecision > precision)
return parseFloat(num.toFixed(precision))
return num
}
export type IParamIteProps = {
id: string
name: string
tip: string
value: number | string[]
step?: number
min?: number
max: number
precision: number | null
onChange: (key: string, value: number | string[]) => void
inputType?: 'inputTag' | 'slider'
}
const TIMES_TEMPLATE = '1000000000000'
const ParamItem: FC<IParamIteProps> = ({ id, name, tip, step = 0.1, min = 0, max, precision, value, inputType, onChange }) => {
const { t } = useTranslation()
const getToIntTimes = (num: number) => {
if (precision)
return parseInt(TIMES_TEMPLATE.slice(0, precision + 1), 10)
if (num < 5)
return 10
return 1
}
const times = getToIntTimes(max)
useEffect(() => {
if (precision)
onChange(id, getFitPrecisionValue(value, precision))
}, [value, precision])
return (
<div className="flex items-center justify-between flex-wrap gap-y-2">
<div className="flex flex-col flex-shrink-0">
<div className="flex items-center">
<span className="mr-[6px] text-gray-500 text-[13px] font-medium">{name}</span>
{/* Give tooltip different tip to avoiding hide bug */}
<Tooltip htmlContent={<div className="w-[200px] whitespace-pre-wrap">{tip}</div>} position='top' selector={`param-name-tooltip-${id}`}>
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8.66667 10.6667H8V8H7.33333M8 5.33333H8.00667M14 8C14 8.78793 13.8448 9.56815 13.5433 10.2961C13.2417 11.0241 12.7998 11.6855 12.2426 12.2426C11.6855 12.7998 11.0241 13.2417 10.2961 13.5433C9.56815 13.8448 8.78793 14 8 14C7.21207 14 6.43185 13.8448 5.7039 13.5433C4.97595 13.2417 4.31451 12.7998 3.75736 12.2426C3.20021 11.6855 2.75825 11.0241 2.45672 10.2961C2.15519 9.56815 2 8.78793 2 8C2 6.4087 2.63214 4.88258 3.75736 3.75736C4.88258 2.63214 6.4087 2 8 2C9.5913 2 11.1174 2.63214 12.2426 3.75736C13.3679 4.88258 14 6.4087 14 8Z" stroke="#9CA3AF" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
</svg>
</Tooltip>
</div>
{inputType === 'inputTag' && <div className="text-gray-400 text-xs font-normal">{t('common.model.params.stop_sequencesPlaceholder')}</div>}
</div>
<div className="flex items-center">
{inputType === 'inputTag'
? <TagInput
items={(value ?? []) as string[]}
onChange={newSequences => onChange(id, newSequences)}
customizedConfirmKey='Tab'
/>
: (
<>
<div className="mr-4 w-[120px]">
<Slider value={value * times} min={min * times} max={max * times} onChange={(value) => {
onChange(id, value / times)
}} />
</div>
<input type="number" min={min} max={max} step={step} className="block w-[64px] h-9 leading-9 rounded-lg border-0 pl-1 pl py-1.5 bg-gray-50 text-gray-900 placeholder:text-gray-400 focus:ring-1 focus:ring-inset focus:ring-primary-600" value={value} onChange={(e) => {
let value = getFitPrecisionValue(isNaN(parseFloat(e.target.value)) ? min : parseFloat(e.target.value), precision)
if (value < min)
value = min
if (value > max)
value = max
onChange(id, value)
}} />
</>
)
}
</div>
</div>
)
}
export default React.memo(ParamItem)

View File

@ -1,24 +0,0 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useContext } from 'use-context-selector'
import I18n from '@/context/i18n'
import type { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import ProviderConfig from '@/app/components/header/account-setting/model-page/configs'
export type IProviderNameProps = {
provideName: ProviderEnum
}
const ProviderName: FC<IProviderNameProps> = ({
provideName,
}) => {
const { locale } = useContext(I18n)
return (
<span>
{ProviderConfig[provideName]?.selector?.name[locale]}
</span>
)
}
export default React.memo(ProviderName)

View File

@ -1,16 +1,17 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
export interface IModalFootProps {
export type IModalFootProps = {
onConfirm: () => void
onCancel: () => void
}
const ModalFoot: FC<IModalFootProps> = ({
onConfirm,
onCancel
onCancel,
}) => {
const { t } = useTranslation()
return (

View File

@ -1,8 +1,9 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
import GroupName from '@/app/components/app/configuration/base/group-name'
export interface IFeatureGroupProps {
export type IFeatureGroupProps = {
title: string
description?: string
children: React.ReactNode
@ -11,7 +12,7 @@ export interface IFeatureGroupProps {
const FeatureGroup: FC<IFeatureGroupProps> = ({
title,
description,
children
children,
}) => {
return (
<div className='mb-6'>

View File

@ -1,14 +1,15 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
export interface ITypeIconProps {
export type ITypeIconProps = {
type: 'upload_file'
size?: 'md' | 'lg'
}
// data_source_type: current only support upload_file
const Icon = ({ type, size = "lg" }: ITypeIconProps) => {
const len = size === "lg" ? 32 : 24
const Icon = ({ type, size = 'lg' }: ITypeIconProps) => {
const len = size === 'lg' ? 32 : 24
const iconMap = {
upload_file: (
<svg width={len} height={len} viewBox="0 0 32 32" fill="none" xmlns="http://www.w3.org/2000/svg">
@ -16,10 +17,9 @@ const Icon = ({ type, size = "lg" }: ITypeIconProps) => {
<path fillRule="evenodd" clipRule="evenodd" d="M8.66669 12.1078C8.66668 11.7564 8.66667 11.4532 8.68707 11.2035C8.7086 10.9399 8.75615 10.6778 8.88468 10.4255C9.07642 10.0492 9.38238 9.74322 9.75871 9.55147C10.011 9.42294 10.2731 9.3754 10.5367 9.35387C10.7864 9.33346 11.0896 9.33347 11.441 9.33349L14.0978 9.33341C14.4935 9.33289 14.8415 9.33243 15.1615 9.4428C15.4417 9.53946 15.697 9.69722 15.9087 9.90465C16.1506 10.1415 16.3058 10.4529 16.4823 10.8071L17.0786 12H19.4942C20.0309 12 20.4738 12 20.8346 12.0295C21.2093 12.0601 21.5538 12.1258 21.8773 12.2907C22.3791 12.5463 22.787 12.9543 23.0427 13.456C23.2076 13.7796 23.2733 14.1241 23.3039 14.4988C23.3334 14.8596 23.3334 15.3025 23.3334 15.8391V18.8276C23.3334 19.3642 23.3334 19.8071 23.3039 20.1679C23.2733 20.5426 23.2076 20.8871 23.0427 21.2107C22.787 21.7124 22.3791 22.1204 21.8773 22.376C21.5538 22.5409 21.2093 22.6066 20.8346 22.6372C20.4738 22.6667 20.0309 22.6667 19.4942 22.6667H12.5058C11.9692 22.6667 11.5263 22.6667 11.1655 22.6372C10.7907 22.6066 10.4463 22.5409 10.1227 22.376C9.62095 22.1204 9.213 21.7124 8.95734 21.2107C8.79248 20.8871 8.72677 20.5426 8.69615 20.1679C8.66667 19.8071 8.66668 19.3642 8.66669 18.8276V12.1078ZM14.0149 10.6668C14.5418 10.6668 14.6463 10.6755 14.7267 10.7033C14.8201 10.7355 14.9052 10.7881 14.9758 10.8572C15.0366 10.9167 15.0911 11.0063 15.3267 11.4776L15.5879 12L10.0001 12C10.0004 11.69 10.0024 11.4781 10.016 11.312C10.0308 11.1309 10.0559 11.0638 10.0727 11.0308C10.1366 10.9054 10.2386 10.8034 10.364 10.7395C10.397 10.7227 10.4641 10.6976 10.6452 10.6828C10.8341 10.6673 11.0823 10.6668 11.4667 10.6668H14.0149Z" fill="#444CE7" />
<rect x="0.25" y="0.25" width="31.5" height="31.5" rx="7.75" stroke="#E0EAFF" strokeWidth="0.5" />
</svg>
)
),
}
return iconMap[type]
}
const TypeIcon: FC<ITypeIconProps> = ({

View File

@ -1,11 +1,12 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import GroupName from '../../base/group-name'
import MoreLikeThis from './more-like-this'
/*
* Include
* Include
* 1. More like this
*/
const ExperienceEnchanceGroup: FC = () => {

View File

@ -1,10 +1,11 @@
'use client'
import React, { FC } from 'react'
import Panel from '@/app/components/app/configuration/base/feature-panel'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import MoreLikeThisIcon from '../../../base/icons/more-like-this-icon'
import { XMarkIcon } from '@heroicons/react/24/outline'
import { useLocalStorageState } from 'ahooks'
import MoreLikeThisIcon from '../../../base/icons/more-like-this-icon'
import Panel from '@/app/components/app/configuration/base/feature-panel'
const GENERATE_NUM = 1

View File

@ -28,7 +28,6 @@ import type { ExternalDataTool } from '@/models/common'
import type { DataSet } from '@/models/datasets'
import type { ModelConfig as BackendModelConfig, VisionSettings } from '@/types/app'
import ConfigContext from '@/context/debug-configuration'
// import ConfigModel from '@/app/components/app/configuration/config-model'
import Config from '@/app/components/app/configuration/config'
import Debug from '@/app/components/app/configuration/debug'
import Confirm from '@/app/components/base/confirm'

View File

@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import { format } from '@/service/base'
import React from 'react'
import { format } from '@/service/base'
export type ITextGenerationProps = {
value: string
@ -16,7 +16,7 @@ const TextGeneration: FC<ITextGenerationProps> = ({
<div
className={className}
dangerouslySetInnerHTML={{
__html: format(value)
__html: format(value),
}}
>
</div>

View File

@ -1,8 +1,9 @@
'use client'
import React, { FC } from 'react'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
interface IAppUnavailableProps {
type IAppUnavailableProps = {
code?: number
isUnknwonReason?: boolean
unknownReason?: string

View File

@ -1,11 +1,11 @@
'use client'
import React, { FC, useEffect } from 'react'
import type { FC } from 'react'
import React, { useEffect } from 'react'
import cn from 'classnames'
import { useBoolean } from 'ahooks'
import { ChevronRightIcon } from '@heroicons/react/24/outline'
export interface IPanelProps {
export type IPanelProps = {
className?: string
headerIcon: React.ReactNode
title: React.ReactNode
@ -30,23 +30,21 @@ const Panel: FC<IPanelProps> = ({
foldDisabled = false,
onFoldChange,
controlUnFold,
controlFold
controlFold,
}) => {
const [fold, { setTrue: setFold, setFalse: setUnFold, toggle: toggleFold }] = useBoolean(keepUnFold ? false : true)
const [fold, { setTrue: setFold, setFalse: setUnFold, toggle: toggleFold }] = useBoolean(!keepUnFold)
useEffect(() => {
onFoldChange?.(fold)
}, [fold])
useEffect(() => {
if (controlUnFold) {
if (controlUnFold)
setUnFold()
}
}, [controlUnFold])
useEffect(() => {
if (controlFold) {
if (controlFold)
setFold()
}
}, [controlFold])
// overflow-hidden

View File

@ -3,6 +3,8 @@ import React, { useState } from 'react'
import { useRouter } from 'next/navigation'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import cn from 'classnames'
import s from './index.module.css'
import Modal from '@/app/components/base/modal'
import Input from '@/app/components/base/input'
import Button from '@/app/components/base/button'
@ -10,12 +12,9 @@ import Button from '@/app/components/base/button'
import { ToastContext } from '@/app/components/base/toast'
import { createEmptyDataset } from '@/service/datasets'
import cn from 'classnames'
import s from './index.module.css'
type IProps = {
show: boolean,
onHide: () => void,
show: boolean
onHide: () => void
}
const EmptyDatasetCreationModal = ({
@ -27,7 +26,7 @@ const EmptyDatasetCreationModal = ({
const { notify } = useContext(ToastContext)
const router = useRouter()
const submit = async () => {
const submit = async () => {
if (!inputValue) {
notify({ type: 'error', message: t('datasetCreation.stepOne.modal.nameNotEmpty') })
return
@ -43,7 +42,6 @@ const EmptyDatasetCreationModal = ({
}
catch (err) {
notify({ type: 'error', message: t('datasetCreation.stepOne.modal.failed') })
return
}
}

View File

@ -1,16 +1,15 @@
'use client'
import React from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import s from './index.module.css'
import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button'
import cn from 'classnames'
import s from './index.module.css'
type IProps = {
show: boolean,
onConfirm: () => void,
onHide: () => void,
show: boolean
onConfirm: () => void
onHide: () => void
}
const StopEmbeddingModal = ({
@ -34,7 +33,7 @@ const StopEmbeddingModal = ({
<div className={s.icon}/>
<span className={s.close} onClick={onHide}/>
<div className={s.title}>{t('datasetCreation.stepThree.modelTitle')}</div>
<div className={s.content}>{t('datasetCreation.stepThree.modelContent')}</div>
<div className={s.content}>{t('datasetCreation.stepThree.modelContent')}</div>
<div className='flex flex-row-reverse'>
<Button className='w-24 ml-2' type='primary' onClick={submit}>{t('datasetCreation.stepThree.modelButtonConfirm')}</Button>
<Button className='w-24' onClick={onHide}>{t('datasetCreation.stepThree.modelButtonCancel')}</Button>

View File

@ -287,7 +287,6 @@ const Metadata: FC<IMetadataProps> = ({ docDetail, loading, onUpdate }) => {
}
const onSave = async () => {
console.log('metadataParams:', metadataParams)
setSaveLoading(true)
const [e] = await asyncRunSafe<CommonResponse>(modifyDocMetadata({
datasetId,

View File

@ -257,7 +257,7 @@ const CodeGroupContext = createContext(false)
export function CodeGroup({ children, title, inputs, targetCode, ...props }: IChildrenProps) {
const languages = Children.map(children, child =>
getPanelTitle(child.props.children.props)
getPanelTitle(child.props.children.props),
)
const tabGroupProps = useTabGroupProps(languages)
const hasTabs = Children.count(children) > 1

View File

@ -1,9 +1,9 @@
'use client'
import { useContext } from 'use-context-selector'
import TemplateEn from './template/template.en.mdx'
import TemplateZh from './template/template.zh.mdx'
import TemplateChatEn from './template/template_chat.en.mdx'
import TemplateChatZh from './template/template_chat.zh.mdx'
import { useContext } from 'use-context-selector'
import I18n from '@/context/i18n'
type IDocProps = {
@ -14,20 +14,21 @@ const Doc = ({ appDetail }: IDocProps) => {
const { locale } = useContext(I18n)
const variables = appDetail?.model_config?.configs?.prompt_variables || []
const inputs = variables.reduce((res: any, variable: any) => {
res[variable.key] = variable.name || '';
res[variable.key] = variable.name || ''
return res
}, {})
return (
<article className="prose prose-xl" >
{appDetail?.mode === 'completion' ? (
locale === 'en' ? <TemplateEn appDetail={appDetail} variables={variables} inputs={inputs} /> : <TemplateZh appDetail={appDetail} variables={variables} inputs={inputs} />
) : (
locale === 'en' ? <TemplateChatEn appDetail={appDetail} variables={variables} inputs={inputs} /> : <TemplateChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
)}
{appDetail?.mode === 'completion'
? (
locale === 'en' ? <TemplateEn appDetail={appDetail} variables={variables} inputs={inputs} /> : <TemplateZh appDetail={appDetail} variables={variables} inputs={inputs} />
)
: (
locale === 'en' ? <TemplateChatEn appDetail={appDetail} variables={variables} inputs={inputs} /> : <TemplateChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
)}
</article>
)
}
export default Doc

Some files were not shown because too many files have changed in this diff Show More