mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 03:07:39 +08:00
Compare commits
9 Commits
revert-942
...
fix/extern
| Author | SHA1 | Date | |
|---|---|---|---|
| 27e2e9f4cd | |||
| 5bd372c11f | |||
| ed4e029609 | |||
| 33d0904981 | |||
| 8456e6379d | |||
| 99967e6fd0 | |||
| c2328cb676 | |||
| 36d3221a05 | |||
| 40f2e7d821 |
7
.github/workflows/api-tests.yml
vendored
7
.github/workflows/api-tests.yml
vendored
@ -27,17 +27,18 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache-dependency-path: |
|
||||
api/pyproject.toml
|
||||
api/poetry.lock
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Check Poetry lockfile
|
||||
run: |
|
||||
poetry check -C api --lock
|
||||
|
||||
3
.github/workflows/build-push.yml
vendored
3
.github/workflows/build-push.yml
vendored
@ -5,6 +5,7 @@ on:
|
||||
branches:
|
||||
- "main"
|
||||
- "deploy/dev"
|
||||
- "fix/external-knowledge-retrieval-issues"
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
@ -125,7 +126,7 @@ jobs:
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
tags: |
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }}
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-beta') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
|
||||
7
.github/workflows/db-migration-test.yml
vendored
7
.github/workflows/db-migration-test.yml
vendored
@ -23,17 +23,18 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache-dependency-path: |
|
||||
api/pyproject.toml
|
||||
api/poetry.lock
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install -C api
|
||||
|
||||
|
||||
7
.github/workflows/style.yml
vendored
7
.github/workflows/style.yml
vendored
@ -24,16 +24,15 @@ jobs:
|
||||
with:
|
||||
files: api/**
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Python dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: poetry install -C api --only lint
|
||||
|
||||
@ -20,9 +20,6 @@ FILES_URL=http://127.0.0.1:5001
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
|
||||
@ -42,7 +39,7 @@ DB_DATABASE=dify
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs, supabase
|
||||
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs
|
||||
STORAGE_TYPE=local
|
||||
STORAGE_LOCAL_PATH=storage
|
||||
S3_USE_AWS_MANAGED_IAM=false
|
||||
@ -102,16 +99,11 @@ VOLCENGINE_TOS_ACCESS_KEY=your-access-key
|
||||
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
|
||||
VOLCENGINE_TOS_REGION=your-region
|
||||
|
||||
# Supabase Storage Configuration
|
||||
SUPABASE_BUCKET_NAME=your-bucket-name
|
||||
SUPABASE_API_KEY=your-access-key
|
||||
SUPABASE_URL=your-server-url
|
||||
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
@ -211,24 +203,6 @@ OPENSEARCH_USER=admin
|
||||
OPENSEARCH_PASSWORD=admin
|
||||
OPENSEARCH_SECURE=true
|
||||
|
||||
# Baidu configuration
|
||||
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
|
||||
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
|
||||
BAIDU_VECTOR_DB_ACCOUNT=root
|
||||
BAIDU_VECTOR_DB_API_KEY=dify
|
||||
BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
|
||||
# ViKingDB configuration
|
||||
VIKINGDB_ACCESS_KEY=your-ak
|
||||
VIKINGDB_SECRET_KEY=your-sk
|
||||
VIKINGDB_REGION=cn-shanghai
|
||||
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
|
||||
VIKINGDB_SCHEMA=http
|
||||
VIKINGDB_CONNECTION_TIMEOUT=30
|
||||
VIKINGDB_SOCKET_TIMEOUT=30
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
|
||||
@ -85,4 +85,3 @@
|
||||
cd ../
|
||||
poetry run -C api bash dev/pytest/pytest_all_tests.sh
|
||||
```
|
||||
|
||||
|
||||
@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login):
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get("user_id")
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
|
||||
if logged_in_account:
|
||||
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
||||
return logged_in_account
|
||||
|
||||
@ -347,14 +347,6 @@ def migrate_knowledge_vector_database():
|
||||
index_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.BAIDU:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.BAIDU,
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
||||
@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class AuthConfig(BaseSettings):
|
||||
class OAuthConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for authentication and OAuth
|
||||
Configuration for OAuth authentication
|
||||
"""
|
||||
|
||||
OAUTH_REDIRECT_PATH: str = Field(
|
||||
@ -371,7 +371,7 @@ class AuthConfig(BaseSettings):
|
||||
)
|
||||
|
||||
GITHUB_CLIENT_ID: Optional[str] = Field(
|
||||
description="GitHub OAuth client ID",
|
||||
description="GitHub OAuth client secret",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@ -390,11 +390,6 @@ class AuthConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
|
||||
description="Expiration time for access tokens in minutes",
|
||||
default=60,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -612,7 +607,6 @@ class PositionConfig(BaseSettings):
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
DataSetConfig,
|
||||
@ -627,13 +621,14 @@ class FeatureConfig(
|
||||
MailConfig,
|
||||
ModelLoadBalanceConfig,
|
||||
ModerationConfig,
|
||||
PositionConfig,
|
||||
OAuthConfig,
|
||||
RagEtlConfig,
|
||||
SecurityConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
WorkspaceConfig,
|
||||
PositionConfig,
|
||||
# hosted services config
|
||||
HostedServiceConfig,
|
||||
CeleryBeatConfig,
|
||||
|
||||
@ -12,7 +12,6 @@ from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageC
|
||||
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
|
||||
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
|
||||
from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
|
||||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
@ -28,7 +27,6 @@ from configs.middleware.vdb.qdrant_config import QdrantConfig
|
||||
from configs.middleware.vdb.relyt_config import RelytConfig
|
||||
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
|
||||
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
|
||||
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
|
||||
from configs.middleware.vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
|
||||
@ -224,7 +222,6 @@ class MiddlewareConfig(
|
||||
HuaweiCloudOBSStorageConfig,
|
||||
OCIStorageConfig,
|
||||
S3StorageConfig,
|
||||
SupabaseStorageConfig,
|
||||
TencentCloudCOSStorageConfig,
|
||||
VolcengineTOSStorageConfig,
|
||||
# configs of vdb and vdb providers
|
||||
@ -244,6 +241,5 @@ class MiddlewareConfig(
|
||||
WeaviateConfig,
|
||||
ElasticsearchConfig,
|
||||
InternalTestConfig,
|
||||
VikingDBConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
@ -1,24 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SupabaseStorageConfig(BaseModel):
|
||||
"""
|
||||
Configuration settings for Supabase Object Storage Service
|
||||
"""
|
||||
|
||||
SUPABASE_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SUPABASE_API_KEY: Optional[str] = Field(
|
||||
description="API KEY for authenticating with Supabase",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SUPABASE_URL: Optional[str] = Field(
|
||||
description="URL of the Supabase",
|
||||
default=None,
|
||||
)
|
||||
@ -1,45 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class BaiduVectorDBConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Baidu Vector Database
|
||||
"""
|
||||
|
||||
BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
|
||||
description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
|
||||
description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
|
||||
default=30000,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
|
||||
description="Account for authenticating with the Baidu Vector Database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
|
||||
description="API key for authenticating with the Baidu Vector Database service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
|
||||
description="Name of the specific Baidu Vector Database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
|
||||
description="Number of shards for the Baidu Vector Database (default is 1)",
|
||||
default=1,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
|
||||
description="Number of replicas for the Baidu Vector Database (default is 3)",
|
||||
default=3,
|
||||
)
|
||||
@ -14,7 +14,7 @@ class OracleConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
ORACLE_PORT: PositiveInt = Field(
|
||||
ORACLE_PORT: Optional[PositiveInt] = Field(
|
||||
description="Port number on which the Oracle database server is listening (default is 1521)",
|
||||
default=1521,
|
||||
)
|
||||
|
||||
@ -14,7 +14,7 @@ class PGVectorConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_PORT: PositiveInt = Field(
|
||||
PGVECTOR_PORT: Optional[PositiveInt] = Field(
|
||||
description="Port number on which the PostgreSQL server is listening (default is 5433)",
|
||||
default=5433,
|
||||
)
|
||||
|
||||
@ -14,7 +14,7 @@ class PGVectoRSConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTO_RS_PORT: PositiveInt = Field(
|
||||
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
|
||||
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
|
||||
default=5431,
|
||||
)
|
||||
|
||||
@ -1,49 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VikingDBConfig(BaseModel):
|
||||
"""
|
||||
Configuration for connecting to Volcengine VikingDB.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
https://www.volcengine.com/docs/6291/65568
|
||||
"""
|
||||
|
||||
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
|
||||
description="The Access Key provided by Volcengine VikingDB for API authentication."
|
||||
"Refer to the following documentation for details on obtaining credentials:"
|
||||
"https://www.volcengine.com/docs/6291/65568",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VIKINGDB_SECRET_KEY: Optional[str] = Field(
|
||||
description="The Secret Key provided by Volcengine VikingDB for API authentication.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VIKINGDB_REGION: str = Field(
|
||||
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
|
||||
default="cn-shanghai",
|
||||
)
|
||||
|
||||
VIKINGDB_HOST: str = Field(
|
||||
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
|
||||
'api-vikingdb.mlp.cn-shanghai.volces.com')",
|
||||
default="api-vikingdb.mlp.cn-shanghai.volces.com",
|
||||
)
|
||||
|
||||
VIKINGDB_SCHEME: str = Field(
|
||||
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
|
||||
default="http",
|
||||
)
|
||||
|
||||
VIKINGDB_CONNECTION_TIMEOUT: int = Field(
|
||||
description="The connection timeout of the Volcengine VikingDB service.",
|
||||
default=30,
|
||||
)
|
||||
|
||||
VIKINGDB_SOCKET_TIMEOUT: int = Field(
|
||||
description="The socket timeout of the Volcengine VikingDB service.",
|
||||
default=30,
|
||||
)
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.9.2",
|
||||
default="0.9.1-fix1",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import email, get_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -40,16 +40,17 @@ class LoginApi(Resource):
|
||||
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
||||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
AccountService.logout(account=account)
|
||||
token = request.headers.get("Authorization", "").split(" ")[1]
|
||||
AccountService.logout(account=account, token=token)
|
||||
flask_login.logout_user()
|
||||
return {"result": "success"}
|
||||
|
||||
@ -105,19 +106,5 @@ class ResetPasswordApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class RefreshTokenApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("refresh_token", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
new_token_pair = AccountService.refresh_token(args["refresh_token"])
|
||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
||||
except Exception as e:
|
||||
return {"result": "fail", "data": str(e)}, 401
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(RefreshTokenApi, "/refresh-token")
|
||||
|
||||
@ -9,7 +9,7 @@ from flask_restful import Resource
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.helper import get_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
@ -81,14 +81,9 @@ class OAuthCallback(Resource):
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
token_pair = AccountService.login(
|
||||
account=account,
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
|
||||
@ -617,8 +617,6 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
@ -655,8 +653,6 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
|
||||
@ -1,24 +1,88 @@
|
||||
from flask_restful import Resource
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
class HitTestingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
self.hit_testing_args_check(args)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||
|
||||
@ -1,85 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
||||
@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import StrLen, email, extract_remote_ip
|
||||
from libs.helper import StrLen, email, get_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup
|
||||
from services.account_service import RegisterService, TenantService
|
||||
@ -46,7 +46,7 @@ class SetupApi(Resource):
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
|
||||
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@ -5,6 +5,7 @@ from libs.external_api import ExternalApi
|
||||
bp = Blueprint("service_api", __name__, url_prefix="/v1")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import index
|
||||
from .app import app, audio, completion, conversation, file, message, workflow
|
||||
from .dataset import dataset, document, hit_testing, segment
|
||||
from .dataset import dataset, document, segment
|
||||
|
||||
@ -4,6 +4,7 @@ from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from constants import UUID_NIL
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
@ -107,6 +108,7 @@ class ChatApi(Resource):
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, default=UUID_NIL, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
|
||||
|
||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@ -62,8 +62,6 @@ class CotAgentOutputParser:
|
||||
thought_str = "thought:"
|
||||
thought_idx = 0
|
||||
|
||||
last_character = ""
|
||||
|
||||
for response in llm_response:
|
||||
if response.delta.usage:
|
||||
usage_dict["usage"] = response.delta.usage
|
||||
@ -76,38 +74,35 @@ class CotAgentOutputParser:
|
||||
while index < len(response):
|
||||
steps = 1
|
||||
delta = response[index : index + steps]
|
||||
yield_delta = False
|
||||
last_character = response[index - 1] if index > 0 else ""
|
||||
|
||||
if delta == "`":
|
||||
last_character = delta
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count += 1
|
||||
else:
|
||||
if not in_code_block:
|
||||
if code_block_delimiter_count > 0:
|
||||
last_character = delta
|
||||
yield code_block_cache
|
||||
code_block_cache = ""
|
||||
else:
|
||||
last_character = delta
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count = 0
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
yield_delta = True
|
||||
else:
|
||||
last_character = delta
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
elif delta.lower() == action_str[action_idx] and action_idx > 0:
|
||||
last_character = delta
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
@ -117,25 +112,24 @@ class CotAgentOutputParser:
|
||||
continue
|
||||
else:
|
||||
if action_cache:
|
||||
last_character = delta
|
||||
yield action_cache
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
|
||||
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
yield_delta = True
|
||||
else:
|
||||
last_character = delta
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
|
||||
last_character = delta
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
@ -145,20 +139,12 @@ class CotAgentOutputParser:
|
||||
continue
|
||||
else:
|
||||
if thought_cache:
|
||||
last_character = delta
|
||||
yield thought_cache
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
|
||||
if yield_delta:
|
||||
index += steps
|
||||
last_character = delta
|
||||
yield delta
|
||||
continue
|
||||
|
||||
if code_block_delimiter_count == 3:
|
||||
if in_code_block:
|
||||
last_character = delta
|
||||
yield from extra_json_from_code_block(code_block_cache)
|
||||
code_block_cache = ""
|
||||
|
||||
@ -170,10 +156,8 @@ class CotAgentOutputParser:
|
||||
if delta == "{":
|
||||
json_quote_count += 1
|
||||
in_json = True
|
||||
last_character = delta
|
||||
json_cache += delta
|
||||
elif delta == "}":
|
||||
last_character = delta
|
||||
json_cache += delta
|
||||
if json_quote_count > 0:
|
||||
json_quote_count -= 1
|
||||
@ -184,19 +168,16 @@ class CotAgentOutputParser:
|
||||
continue
|
||||
else:
|
||||
if in_json:
|
||||
last_character = delta
|
||||
json_cache += delta
|
||||
|
||||
if got_json:
|
||||
got_json = False
|
||||
last_character = delta
|
||||
yield parse_action(json_cache)
|
||||
json_cache = ""
|
||||
json_quote_count = 0
|
||||
in_json = False
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
last_character = delta
|
||||
yield delta.replace("`", "")
|
||||
|
||||
index += steps
|
||||
|
||||
@ -10,7 +10,6 @@ from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
import contexts
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
@ -123,7 +122,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@ -56,7 +56,6 @@ from models.account import Account
|
||||
from models.model import Conversation, EndUser, Message
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
@ -73,7 +72,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -117,7 +115,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import Any, Literal, Union, overload
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
@ -128,7 +127,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import Any, Literal, Union, overload
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
@ -129,7 +128,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@ -52,7 +52,6 @@ from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
@ -70,7 +69,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -105,7 +103,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
||||
@ -2,9 +2,8 @@ from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file.file_obj import FileVar
|
||||
@ -117,36 +116,13 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ConversationAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
Base entity for conversation-based app generation.
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
|
||||
"For service API, we need to ensure its forward compatibility, "
|
||||
"so passing in the parent_message_id as request arg is not supported for now. "
|
||||
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("parent_message_id")
|
||||
@classmethod
|
||||
def validate_parent_message_id(cls, v, info: ValidationInfo):
|
||||
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
|
||||
raise ValueError("parent_message_id should be UUID_NIL for service API")
|
||||
return v
|
||||
|
||||
|
||||
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
pass
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
|
||||
|
||||
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
@ -157,15 +133,16 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
pass
|
||||
|
||||
|
||||
class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||
class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
Agent Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
pass
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
Advanced Chat Application Generate Entity.
|
||||
"""
|
||||
@ -173,6 +150,8 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
workflow_run_id: Optional[str] = None
|
||||
query: str
|
||||
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
import logging
|
||||
from threading import Thread
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
@ -85,9 +83,7 @@ class MessageCycleManage:
|
||||
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
|
||||
conversation.name = name
|
||||
except Exception as e:
|
||||
if dify_config.DEBUG:
|
||||
logging.exception(f"generate conversation name failed: {e}")
|
||||
pass
|
||||
logging.exception(f"generate conversation name failed: {e}")
|
||||
|
||||
db.session.merge(conversation)
|
||||
db.session.commit()
|
||||
|
||||
@ -57,7 +57,6 @@ class WorkflowCycleManage:
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def _handle_workflow_run_start(self) -> WorkflowRun:
|
||||
max_sequence = (
|
||||
@ -252,8 +251,6 @@ class WorkflowCycleManage:
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
@ -266,36 +263,20 @@ class WorkflowCycleManage:
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
|
||||
{
|
||||
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
|
||||
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
|
||||
WorkflowNodeExecution.execution_metadata: execution_metadata,
|
||||
WorkflowNodeExecution.finished_at: finished_at,
|
||||
WorkflowNodeExecution.elapsed_time: elapsed_time,
|
||||
}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
||||
|
||||
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
@ -309,33 +290,18 @@ class WorkflowCycleManage:
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
|
||||
{
|
||||
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
|
||||
WorkflowNodeExecution.error: event.error,
|
||||
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
|
||||
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
|
||||
WorkflowNodeExecution.finished_at: finished_at,
|
||||
WorkflowNodeExecution.elapsed_time: elapsed_time,
|
||||
}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
||||
|
||||
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
@ -712,7 +678,17 @@ class WorkflowCycleManage:
|
||||
:param node_execution_id: workflow node execution id
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id)
|
||||
workflow_node_execution = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
|
||||
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
|
||||
WorkflowNodeExecution.workflow_id == self._workflow.id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not workflow_node_execution:
|
||||
raise Exception(f"Workflow node execution not found: {node_execution_id}")
|
||||
|
||||
@ -5,7 +5,6 @@ from typing import Optional, cast
|
||||
import numpy as np
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from configs import dify_config
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
@ -111,8 +110,6 @@ class CacheEmbedding(Embeddings):
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
|
||||
except Exception as ex:
|
||||
if dify_config.DEBUG:
|
||||
logging.exception(f"Failed to embed query text: {ex}")
|
||||
raise ex
|
||||
|
||||
try:
|
||||
@ -125,8 +122,6 @@ class CacheEmbedding(Embeddings):
|
||||
encoded_str = encoded_vector.decode("utf-8")
|
||||
redis_client.setex(embedding_cache_key, 600, encoded_str)
|
||||
except Exception as ex:
|
||||
if dify_config.DEBUG:
|
||||
logging.exception("Failed to add embedding to redis %s", ex)
|
||||
raise ex
|
||||
logging.exception("Failed to add embedding to redis %s", ex)
|
||||
|
||||
return embedding_results
|
||||
|
||||
@ -60,8 +60,8 @@ class TokenBufferMemory:
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
|
||||
if thread_messages and not thread_messages[0].answer:
|
||||
thread_messages.pop(0)
|
||||
if thread_messages and not thread_messages[-1].answer:
|
||||
thread_messages.pop()
|
||||
|
||||
messages = list(reversed(thread_messages))
|
||||
|
||||
|
||||
@ -1098,14 +1098,6 @@ LLM_BASE_MODELS = [
|
||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||
@ -1143,14 +1135,6 @@ LLM_BASE_MODELS = [
|
||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||
|
||||
@ -119,15 +119,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
try:
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
|
||||
if model.startswith("o1"):
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
model=model,
|
||||
temperature=1,
|
||||
max_completion_tokens=20,
|
||||
stream=False,
|
||||
)
|
||||
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
# chat model
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "ping"}],
|
||||
|
||||
@ -18,7 +18,6 @@ supported_model_types:
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: fireworks_api_key
|
||||
@ -29,75 +28,3 @@ provider_credential_schema:
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model URL
|
||||
zh_Hans: 模型URL
|
||||
placeholder:
|
||||
en_US: Enter your Model URL
|
||||
zh_Hans: 输入模型URL
|
||||
credential_form_schemas:
|
||||
- variable: model_label_zh_Hanns
|
||||
label:
|
||||
zh_Hans: 模型中文名称
|
||||
en_US: The zh_Hans of Model
|
||||
required: true
|
||||
type: text-input
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型中文名称
|
||||
en_US: Enter your zh_Hans of Model
|
||||
- variable: model_label_en_US
|
||||
label:
|
||||
zh_Hans: 模型英文名称
|
||||
en_US: The en_US of Model
|
||||
required: true
|
||||
type: text-input
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型英文名称
|
||||
en_US: Enter your en_US of Model
|
||||
- variable: fireworks_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 模型上下文长度
|
||||
en_US: Model context size
|
||||
required: true
|
||||
type: text-input
|
||||
default: '4096'
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
- variable: max_tokens
|
||||
label:
|
||||
zh_Hans: 最大 token 上限
|
||||
en_US: Upper bound for max tokens
|
||||
default: '4096'
|
||||
type: text-input
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- variable: function_calling_type
|
||||
label:
|
||||
en_US: Function calling
|
||||
type: select
|
||||
required: false
|
||||
default: no_call
|
||||
options:
|
||||
- value: no_call
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- value: function_call
|
||||
label:
|
||||
en_US: Support
|
||||
zh_Hans: 支持
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
|
||||
@ -43,4 +43,3 @@ pricing:
|
||||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
deprecated: true
|
||||
|
||||
@ -8,8 +8,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
@ -21,15 +20,6 @@ from core.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
||||
@ -618,50 +608,3 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
|
||||
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=credentials.get("model_label_en_US", model),
|
||||
zh_Hans=credentials.get("model_label_zh_Hanns", model),
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||
if credentials.get("function_calling_type") == "function_call"
|
||||
else [],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
use_template="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="max_tokens",
|
||||
use_template="max_tokens",
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get("max_tokens", 4096)),
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||
type=ParameterType.INT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
use_template="top_p",
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_k",
|
||||
use_template="top_k",
|
||||
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@ -1,46 +0,0 @@
|
||||
model: accounts/fireworks/models/qwen2p5-72b-instruct
|
||||
label:
|
||||
zh_Hans: Qwen2.5 72B Instruct
|
||||
en_US: Qwen2.5 72B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.9'
|
||||
output: '0.9'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -27,6 +27,15 @@ parameter_rules:
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -31,6 +31,15 @@ parameter_rules:
|
||||
max: 2048
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流式输出
|
||||
en_US: Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
||||
default: false
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@ -1,44 +0,0 @@
|
||||
model: abab6.5t-chat
|
||||
label:
|
||||
en_US: Abab6.5t-Chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.01
|
||||
max: 1
|
||||
default: 0.9
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0.01
|
||||
max: 1
|
||||
default: 0.95
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 3072
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: mask_sensitive_info
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
zh_Hans: 隐私保护
|
||||
en_US: Moderate
|
||||
help:
|
||||
zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
|
||||
en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '0.005'
|
||||
output: '0.005'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -61,8 +61,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
url = f"{self.api_base}?GroupId={group_id}"
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
|
||||
embedding_type = "db" if input_type == EmbeddingInputType.DOCUMENT else "query"
|
||||
data = {"model": "embo-01", "texts": texts, "type": embedding_type}
|
||||
data = {"model": "embo-01", "texts": texts, "type": "db"}
|
||||
|
||||
try:
|
||||
response = post(url, headers=headers, data=dumps(data))
|
||||
|
||||
@ -19,9 +19,9 @@ class OpenAIProvider(ModelProvider):
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `gpt-4o-mini` model for validate,
|
||||
# Use `gpt-3.5-turbo` model for validate,
|
||||
# no matter what model you pass in, text completion model or chat model
|
||||
model_instance.validate_credentials(model="gpt-4o-mini", credentials=credentials)
|
||||
model_instance.validate_credentials(model="gpt-3.5-turbo", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
|
||||
@ -28,6 +28,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: return_type
|
||||
label:
|
||||
zh_Hans: 回复类型
|
||||
@ -40,4 +49,3 @@ parameter_rules:
|
||||
options:
|
||||
- text
|
||||
- json_string
|
||||
deprecated: true
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -1,53 +0,0 @@
|
||||
model: glm-4-flashx
|
||||
label:
|
||||
en_US: glm-4-flashx
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.95
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.7
|
||||
help:
|
||||
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
|
||||
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
|
||||
- name: do_sample
|
||||
label:
|
||||
zh_Hans: 采样策略
|
||||
en_US: Sampling strategy
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 4095
|
||||
- name: web_search
|
||||
type: boolean
|
||||
label:
|
||||
zh_Hans: 联网搜索
|
||||
en_US: Web Search
|
||||
default: false
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
pricing:
|
||||
input: '0'
|
||||
output: '0'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -35,6 +35,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -32,6 +32,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -30,6 +30,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -30,6 +30,15 @@ parameter_rules:
|
||||
zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
|
||||
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
|
||||
default: true
|
||||
- name: stream
|
||||
label:
|
||||
zh_Hans: 流处理
|
||||
en_US: Event Stream
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 使用同步调用时,此参数应当设置为 fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。默认值为 false。如果设置为 true,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data:[DONE]消息。注意:在模型流式输出生成内容的过程中,我们会分批对模型生成内容进行检测,当检测到违法及不良信息时,API会返回错误码(1301)。开发者识别到错误码(1301),应及时采取(清屏、重启对话)等措施删除生成内容,并确保不将含有违法及不良信息的内容传递给模型继续生成,避免其造成负面影响。
|
||||
en_US: When using synchronous invocation, this parameter should be set to false or omitted. It indicates that the model will return all the generated content at once after the generation is complete. The default value is false. If set to true, the model will return the generated content in chunks via the standard Event Stream. A data:[DONE] message will be sent at the end of the Event Stream.Note:During the model's streaming output process, we will batch check the generated content. If illegal or harmful information is detected, the API will return an error code (1301). Developers who identify error code (1301) should promptly take actions such as clearing the screen or restarting the conversation to delete the generated content. They should also ensure that no illegal or harmful content is passed back to the model for continued generation to avoid negative impacts.
|
||||
default: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
|
||||
@ -1,10 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai.types.chat.chat_completion import Completion
|
||||
from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
@ -20,6 +16,9 @@ from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
|
||||
|
||||
class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
|
||||
|
||||
@ -0,0 +1,15 @@
|
||||
from .__version__ import __version__
|
||||
from ._client import ZhipuAI
|
||||
from .core import (
|
||||
APIAuthenticationError,
|
||||
APIConnectionError,
|
||||
APIInternalError,
|
||||
APIReachLimitError,
|
||||
APIRequestFailedError,
|
||||
APIResponseError,
|
||||
APIResponseValidationError,
|
||||
APIServerFlowExceedError,
|
||||
APIStatusError,
|
||||
APITimeoutError,
|
||||
ZhipuAIError,
|
||||
)
|
||||
@ -0,0 +1 @@
|
||||
__version__ = "v2.1.0"
|
||||
@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Union
|
||||
|
||||
import httpx
|
||||
from httpx import Timeout
|
||||
from typing_extensions import override
|
||||
|
||||
from . import api_resource
|
||||
from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token
|
||||
|
||||
|
||||
class ZhipuAI(HttpClient):
|
||||
chat: api_resource.chat.Chat
|
||||
api_key: str
|
||||
_disable_token_cache: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
|
||||
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
|
||||
http_client: httpx.Client | None = None,
|
||||
custom_headers: Mapping[str, str] | None = None,
|
||||
disable_token_cache: bool = True,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None:
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("ZHIPUAI_API_KEY")
|
||||
if api_key is None:
|
||||
raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供")
|
||||
self.api_key = api_key
|
||||
self._disable_token_cache = disable_token_cache
|
||||
|
||||
if base_url is None:
|
||||
base_url = os.environ.get("ZHIPUAI_BASE_URL")
|
||||
if base_url is None:
|
||||
base_url = "https://open.bigmodel.cn/api/paas/v4"
|
||||
from .__version__ import __version__
|
||||
|
||||
super().__init__(
|
||||
version=__version__,
|
||||
base_url=base_url,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
custom_httpx_client=http_client,
|
||||
custom_headers=custom_headers,
|
||||
_strict_response_validation=_strict_response_validation,
|
||||
)
|
||||
self.chat = api_resource.chat.Chat(self)
|
||||
self.images = api_resource.images.Images(self)
|
||||
self.embeddings = api_resource.embeddings.Embeddings(self)
|
||||
self.files = api_resource.files.Files(self)
|
||||
self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
|
||||
self.batches = api_resource.Batches(self)
|
||||
self.knowledge = api_resource.Knowledge(self)
|
||||
self.tools = api_resource.Tools(self)
|
||||
self.videos = api_resource.Videos(self)
|
||||
self.assistant = api_resource.Assistant(self)
|
||||
|
||||
@property
|
||||
@override
|
||||
def auth_headers(self) -> dict[str, str]:
|
||||
api_key = self.api_key
|
||||
if self._disable_token_cache:
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
else:
|
||||
return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"}
|
||||
|
||||
def __del__(self) -> None:
|
||||
if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"):
|
||||
# if the '__init__' method raised an error, self would not have client attr
|
||||
return
|
||||
|
||||
if self._has_custom_http_client:
|
||||
return
|
||||
|
||||
self.close()
|
||||
@ -0,0 +1,34 @@
|
||||
from .assistant import (
|
||||
Assistant,
|
||||
)
|
||||
from .batches import Batches
|
||||
from .chat import (
|
||||
AsyncCompletions,
|
||||
Chat,
|
||||
Completions,
|
||||
)
|
||||
from .embeddings import Embeddings
|
||||
from .files import Files, FilesWithRawResponse
|
||||
from .fine_tuning import FineTuning
|
||||
from .images import Images
|
||||
from .knowledge import Knowledge
|
||||
from .tools import Tools
|
||||
from .videos import (
|
||||
Videos,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Videos",
|
||||
"AsyncCompletions",
|
||||
"Chat",
|
||||
"Completions",
|
||||
"Images",
|
||||
"Embeddings",
|
||||
"Files",
|
||||
"FilesWithRawResponse",
|
||||
"FineTuning",
|
||||
"Batches",
|
||||
"Knowledge",
|
||||
"Tools",
|
||||
"Assistant",
|
||||
]
|
||||
@ -0,0 +1,3 @@
|
||||
from .assistant import Assistant
|
||||
|
||||
__all__ = ["Assistant"]
|
||||
@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
StreamResponse,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.assistant import AssistantCompletion
|
||||
from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp
|
||||
from ...types.assistant.assistant_support_resp import AssistantSupportResp
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
from ...types.assistant import assistant_conversation_params, assistant_create_params
|
||||
|
||||
__all__ = ["Assistant"]
|
||||
|
||||
|
||||
class Assistant(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def conversation(
|
||||
self,
|
||||
assistant_id: str,
|
||||
model: str,
|
||||
messages: list[assistant_create_params.ConversationMessage],
|
||||
*,
|
||||
stream: bool = True,
|
||||
conversation_id: Optional[str] = None,
|
||||
attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
|
||||
metadata: dict | None = None,
|
||||
request_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> StreamResponse[AssistantCompletion]:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"assistant_id": assistant_id,
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"conversation_id": conversation_id,
|
||||
"attachments": attachments,
|
||||
"metadata": metadata,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/assistant",
|
||||
body=maybe_transform(body, assistant_create_params.AssistantParameters),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=AssistantCompletion,
|
||||
stream=stream or True,
|
||||
stream_cls=StreamResponse[AssistantCompletion],
|
||||
)
|
||||
|
||||
def query_support(
|
||||
self,
|
||||
*,
|
||||
assistant_id_list: Optional[list[str]] = None,
|
||||
request_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> AssistantSupportResp:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"assistant_id_list": assistant_id_list,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/assistant/list",
|
||||
body=body,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=AssistantSupportResp,
|
||||
)
|
||||
|
||||
def query_conversation_usage(
|
||||
self,
|
||||
assistant_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
*,
|
||||
request_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ConversationUsageListResp:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"assistant_id": assistant_id,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/assistant/conversation/list",
|
||||
body=maybe_transform(body, assistant_conversation_params.ConversationParameters),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=ConversationUsageListResp,
|
||||
)
|
||||
@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform
|
||||
from ..core.pagination import SyncCursorPage
|
||||
from ..types import batch_create_params, batch_list_params
|
||||
from ..types.batch import Batch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class Batches(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
completion_window: str | None = None,
|
||||
endpoint: Literal["/v1/chat/completions", "/v1/embeddings"],
|
||||
input_file_id: str,
|
||||
metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
|
||||
auto_delete_input_file: bool = True,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Batch:
|
||||
return self._post(
|
||||
"/batches",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"completion_window": completion_window,
|
||||
"endpoint": endpoint,
|
||||
"input_file_id": input_file_id,
|
||||
"metadata": metadata,
|
||||
"auto_delete_input_file": auto_delete_input_file,
|
||||
},
|
||||
batch_create_params.BatchCreateParams,
|
||||
),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=Batch,
|
||||
)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
batch_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Batch:
|
||||
"""
|
||||
Retrieves a batch.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not batch_id:
|
||||
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
|
||||
return self._get(
|
||||
f"/batches/{batch_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=Batch,
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> SyncCursorPage[Batch]:
|
||||
"""List your organization's batches.
|
||||
|
||||
Args:
|
||||
after: A cursor for use in pagination.
|
||||
|
||||
`after` is an object ID that defines your place
|
||||
in the list. For instance, if you make a list request and receive 100 objects,
|
||||
ending with obj_foo, your subsequent call can include after=obj_foo in order to
|
||||
fetch the next page of the list.
|
||||
|
||||
limit: A limit on the number of objects to be returned. Limit can range between 1 and
|
||||
100, and the default is 20.
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
return self._get_api_list(
|
||||
"/batches",
|
||||
page=SyncCursorPage[Batch],
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query=maybe_transform(
|
||||
{
|
||||
"after": after,
|
||||
"limit": limit,
|
||||
},
|
||||
batch_list_params.BatchListParams,
|
||||
),
|
||||
),
|
||||
model=Batch,
|
||||
)
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Batch:
|
||||
"""
|
||||
Cancels an in-progress batch.
|
||||
|
||||
Args:
|
||||
batch_id: The ID of the batch to cancel.
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
"""
|
||||
if not batch_id:
|
||||
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
|
||||
return self._post(
|
||||
f"/batches/{batch_id}/cancel",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=Batch,
|
||||
)
|
||||
@ -0,0 +1,5 @@
|
||||
from .async_completions import AsyncCompletions
|
||||
from .chat import Chat
|
||||
from .completions import Completions
|
||||
|
||||
__all__ = ["AsyncCompletions", "Chat", "Completions"]
|
||||
@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
drop_prefix_image_data,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus
|
||||
from ...types.chat.code_geex import code_geex_params
|
||||
from ...types.sensitive_word_check import SensitiveWordCheckRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
|
||||
class AsyncCompletions(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
user_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
temperature: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
top_p: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
max_tokens: int | NotGiven = NOT_GIVEN,
|
||||
seed: int | NotGiven = NOT_GIVEN,
|
||||
messages: Union[str, list[str], list[int], list[list[int]], None],
|
||||
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
|
||||
tools: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
tool_choice: str | NotGiven = NOT_GIVEN,
|
||||
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
|
||||
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> AsyncTaskStatus:
|
||||
_cast_type = AsyncTaskStatus
|
||||
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
|
||||
if temperature is not None and temperature != NOT_GIVEN:
|
||||
if temperature <= 0:
|
||||
do_sample = False
|
||||
temperature = 0.01
|
||||
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperature不生效)") # noqa: E501
|
||||
if temperature >= 1:
|
||||
temperature = 0.99
|
||||
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
|
||||
if top_p is not None and top_p != NOT_GIVEN:
|
||||
if top_p >= 1:
|
||||
top_p = 0.99
|
||||
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
|
||||
if top_p <= 0:
|
||||
top_p = 0.01
|
||||
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
|
||||
|
||||
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
|
||||
if isinstance(messages, list):
|
||||
for item in messages:
|
||||
if item.get("content"):
|
||||
item["content"] = drop_prefix_image_data(item["content"])
|
||||
|
||||
body = {
|
||||
"model": model,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"do_sample": do_sample,
|
||||
"max_tokens": max_tokens,
|
||||
"seed": seed,
|
||||
"messages": messages,
|
||||
"stop": stop,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
"meta": meta,
|
||||
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
|
||||
}
|
||||
return self._post(
|
||||
"/async/chat/completions",
|
||||
body=body,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=_cast_type,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
def retrieve_completion_result(
|
||||
self,
|
||||
id: str,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Union[AsyncCompletion, AsyncTaskStatus]:
|
||||
_cast_type = Union[AsyncCompletion, AsyncTaskStatus]
|
||||
return self._get(
|
||||
path=f"/async-result/{id}",
|
||||
cast_type=_cast_type,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
)
|
||||
@ -0,0 +1,18 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...core import BaseAPI, cached_property
|
||||
from .async_completions import AsyncCompletions
|
||||
from .completions import Completions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class Chat(BaseAPI):
|
||||
@cached_property
|
||||
def completions(self) -> Completions:
|
||||
return Completions(self._client)
|
||||
|
||||
@cached_property
|
||||
def asyncCompletions(self) -> AsyncCompletions: # noqa: N802
|
||||
return AsyncCompletions(self._client)
|
||||
@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
StreamResponse,
|
||||
deepcopy_minimal,
|
||||
drop_prefix_image_data,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.chat.chat_completion import Completion
|
||||
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from ...types.chat.code_geex import code_geex_params
|
||||
from ...types.sensitive_word_check import SensitiveWordCheckRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
|
||||
class Completions(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
user_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
temperature: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
top_p: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
max_tokens: int | NotGiven = NOT_GIVEN,
|
||||
seed: int | NotGiven = NOT_GIVEN,
|
||||
messages: Union[str, list[str], list[int], object, None],
|
||||
stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
|
||||
tools: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
tool_choice: str | NotGiven = NOT_GIVEN,
|
||||
meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
|
||||
extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Completion | StreamResponse[ChatCompletionChunk]:
|
||||
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
|
||||
if temperature is not None and temperature != NOT_GIVEN:
|
||||
if temperature <= 0:
|
||||
do_sample = False
|
||||
temperature = 0.01
|
||||
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperature不生效)") # noqa: E501
|
||||
if temperature >= 1:
|
||||
temperature = 0.99
|
||||
# logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间")
|
||||
if top_p is not None and top_p != NOT_GIVEN:
|
||||
if top_p >= 1:
|
||||
top_p = 0.99
|
||||
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
|
||||
if top_p <= 0:
|
||||
top_p = 0.01
|
||||
# logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1")
|
||||
|
||||
logger.debug(f"temperature:{temperature}, top_p:{top_p}")
|
||||
if isinstance(messages, list):
|
||||
for item in messages:
|
||||
if item.get("content"):
|
||||
item["content"] = drop_prefix_image_data(item["content"])
|
||||
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"model": model,
|
||||
"request_id": request_id,
|
||||
"user_id": user_id,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"do_sample": do_sample,
|
||||
"max_tokens": max_tokens,
|
||||
"seed": seed,
|
||||
"messages": messages,
|
||||
"stop": stop,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
"stream": stream,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
"meta": meta,
|
||||
"extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/chat/completions",
|
||||
body=body,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=Completion,
|
||||
stream=stream or False,
|
||||
stream_cls=StreamResponse[ChatCompletionChunk],
|
||||
)
|
||||
@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
|
||||
from ..types.embeddings import EmbeddingsResponded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class Embeddings(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
input: Union[str, list[str], list[int], list[list[int]]],
|
||||
model: Union[str],
|
||||
dimensions: Union[int] | NotGiven = NOT_GIVEN,
|
||||
encoding_format: str | NotGiven = NOT_GIVEN,
|
||||
user: str | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> EmbeddingsResponded:
|
||||
_cast_type = EmbeddingsResponded
|
||||
if disable_strict_validation:
|
||||
_cast_type = object
|
||||
return self._post(
|
||||
"/embeddings",
|
||||
body={
|
||||
"input": input,
|
||||
"model": model,
|
||||
"dimensions": dimensions,
|
||||
"encoding_format": encoding_format,
|
||||
"user": user,
|
||||
"request_id": request_id,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
},
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=_cast_type,
|
||||
stream=False,
|
||||
)
|
||||
@ -0,0 +1,194 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Literal, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
FileTypes,
|
||||
Headers,
|
||||
NotGiven,
|
||||
_legacy_binary_response,
|
||||
_legacy_response,
|
||||
deepcopy_minimal,
|
||||
extract_files,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
__all__ = ["Files", "FilesWithRawResponse"]
|
||||
|
||||
|
||||
class Files(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
file: Optional[FileTypes] = None,
|
||||
upload_detail: Optional[list[UploadDetail]] = None,
|
||||
purpose: Literal["fine-tune", "retrieval", "batch"],
|
||||
knowledge_id: Optional[str] = None,
|
||||
sentence_size: Optional[int] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FileObject:
|
||||
if not file and not upload_detail:
|
||||
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"file": file,
|
||||
"upload_detail": upload_detail,
|
||||
"purpose": purpose,
|
||||
"knowledge_id": knowledge_id,
|
||||
"sentence_size": sentence_size,
|
||||
}
|
||||
)
|
||||
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
|
||||
if files:
|
||||
# It should be noted that the actual Content-Type header that will be
|
||||
# sent to the server will contain a `boundary` parameter, e.g.
|
||||
# multipart/form-data; boundary=---abc--
|
||||
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
|
||||
return self._post(
|
||||
"/files",
|
||||
body=maybe_transform(body, file_create_params.FileCreateParams),
|
||||
files=files,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FileObject,
|
||||
)
|
||||
|
||||
# def retrieve(
|
||||
# self,
|
||||
# file_id: str,
|
||||
# *,
|
||||
# extra_headers: Headers | None = None,
|
||||
# extra_body: Body | None = None,
|
||||
# timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
# ) -> FileObject:
|
||||
# """
|
||||
# Returns information about a specific file.
|
||||
#
|
||||
# Args:
|
||||
# file_id: The ID of the file to retrieve information about
|
||||
# extra_headers: Send extra headers
|
||||
#
|
||||
# extra_body: Add additional JSON properties to the request
|
||||
#
|
||||
# timeout: Override the client-level default timeout for this request, in seconds
|
||||
# """
|
||||
# if not file_id:
|
||||
# raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
|
||||
# return self._get(
|
||||
# f"/files/{file_id}",
|
||||
# options=make_request_options(
|
||||
# extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
|
||||
# ),
|
||||
# cast_type=FileObject,
|
||||
# )
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
purpose: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
order: str | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ListOfFileObject:
|
||||
return self._get(
|
||||
"/files",
|
||||
cast_type=ListOfFileObject,
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query={
|
||||
"purpose": purpose,
|
||||
"limit": limit,
|
||||
"after": after,
|
||||
"order": order,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
file_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FileDeleted:
|
||||
"""
|
||||
Delete a file.
|
||||
|
||||
Args:
|
||||
file_id: The ID of the file to delete
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not file_id:
|
||||
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
|
||||
return self._delete(
|
||||
f"/files/{file_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FileDeleted,
|
||||
)
|
||||
|
||||
def content(
|
||||
self,
|
||||
file_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> _legacy_response.HttpxBinaryResponseContent:
|
||||
"""
|
||||
Returns the contents of the specified file.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not file_id:
|
||||
raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
|
||||
extra_headers = {"Accept": "application/binary", **(extra_headers or {})}
|
||||
return self._get(
|
||||
f"/files/{file_id}/content",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=_legacy_binary_response.HttpxBinaryResponseContent,
|
||||
)
|
||||
|
||||
|
||||
class FilesWithRawResponse:
|
||||
def __init__(self, files: Files) -> None:
|
||||
self._files = files
|
||||
|
||||
self.create = _legacy_response.to_raw_response_wrapper(
|
||||
files.create,
|
||||
)
|
||||
self.list = _legacy_response.to_raw_response_wrapper(
|
||||
files.list,
|
||||
)
|
||||
self.content = _legacy_response.to_raw_response_wrapper(
|
||||
files.content,
|
||||
)
|
||||
@ -0,0 +1,5 @@
|
||||
from .fine_tuning import FineTuning
|
||||
from .jobs import Jobs
|
||||
from .models import FineTunedModels
|
||||
|
||||
__all__ = ["Jobs", "FineTunedModels", "FineTuning"]
|
||||
@ -0,0 +1,18 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...core import BaseAPI, cached_property
|
||||
from .jobs import Jobs
|
||||
from .models import FineTunedModels
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class FineTuning(BaseAPI):
|
||||
@cached_property
|
||||
def jobs(self) -> Jobs:
|
||||
return Jobs(self._client)
|
||||
|
||||
@cached_property
|
||||
def models(self) -> FineTunedModels:
|
||||
return FineTunedModels(self._client)
|
||||
@ -0,0 +1,3 @@
|
||||
from .jobs import Jobs
|
||||
|
||||
__all__ = ["Jobs"]
|
||||
@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
make_request_options,
|
||||
)
|
||||
from ....types.fine_tuning import (
|
||||
FineTuningJob,
|
||||
FineTuningJobEvent,
|
||||
ListOfFineTuningJob,
|
||||
job_create_params,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...._client import ZhipuAI
|
||||
|
||||
__all__ = ["Jobs"]
|
||||
|
||||
|
||||
class Jobs(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
training_file: str,
|
||||
hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
|
||||
suffix: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
validation_file: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
return self._post(
|
||||
"/fine_tuning/jobs",
|
||||
body={
|
||||
"model": model,
|
||||
"training_file": training_file,
|
||||
"hyperparameters": hyperparameters,
|
||||
"suffix": suffix,
|
||||
"validation_file": validation_file,
|
||||
"request_id": request_id,
|
||||
},
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
return self._get(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ListOfFineTuningJob:
|
||||
return self._get(
|
||||
"/fine_tuning/jobs",
|
||||
cast_type=ListOfFineTuningJob,
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query={
|
||||
"after": after,
|
||||
"limit": limit,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # noqa: E501
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
if not fine_tuning_job_id:
|
||||
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
|
||||
return self._post(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
||||
|
||||
def list_events(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJobEvent:
|
||||
return self._get(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
|
||||
cast_type=FineTuningJobEvent,
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query={
|
||||
"after": after,
|
||||
"limit": limit,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
if not fine_tuning_job_id:
|
||||
raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}")
|
||||
return self._delete(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
||||
@ -0,0 +1,3 @@
|
||||
from .fine_tuned_models import FineTunedModels
|
||||
|
||||
__all__ = ["FineTunedModels"]
|
||||
@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
make_request_options,
|
||||
)
|
||||
from ....types.fine_tuning.models import FineTunedModelsStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...._client import ZhipuAI
|
||||
|
||||
__all__ = ["FineTunedModels"]
|
||||
|
||||
|
||||
class FineTunedModels(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
fine_tuned_model: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTunedModelsStatus:
|
||||
if not fine_tuned_model:
|
||||
raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}")
|
||||
return self._delete(
|
||||
f"fine_tuning/fine_tuned_models/{fine_tuned_model}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=FineTunedModelsStatus,
|
||||
)
|
||||
@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
|
||||
from ..types.image import ImagesResponded
|
||||
from ..types.sensitive_word_check import SensitiveWordCheckRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class Images(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def generations(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str | NotGiven = NOT_GIVEN,
|
||||
n: Optional[int] | NotGiven = NOT_GIVEN,
|
||||
quality: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
response_format: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
size: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
style: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
|
||||
user: str | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
user_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ImagesResponded:
|
||||
_cast_type = ImagesResponded
|
||||
if disable_strict_validation:
|
||||
_cast_type = object
|
||||
return self._post(
|
||||
"/images/generations",
|
||||
body={
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
"n": n,
|
||||
"quality": quality,
|
||||
"response_format": response_format,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
"size": size,
|
||||
"style": style,
|
||||
"user": user,
|
||||
"user_id": user_id,
|
||||
"request_id": request_id,
|
||||
},
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=_cast_type,
|
||||
stream=False,
|
||||
)
|
||||
@ -0,0 +1,3 @@
|
||||
from .knowledge import Knowledge
|
||||
|
||||
__all__ = ["Knowledge"]
|
||||
@ -0,0 +1,3 @@
|
||||
from .document import Document
|
||||
|
||||
__all__ = ["Document"]
|
||||
@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Literal, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
FileTypes,
|
||||
Headers,
|
||||
NotGiven,
|
||||
deepcopy_minimal,
|
||||
extract_files,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ....types.files import UploadDetail, file_create_params
|
||||
from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params
|
||||
from ....types.knowledge.document.document_list_resp import DocumentPage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...._client import ZhipuAI
|
||||
|
||||
__all__ = ["Document"]
|
||||
|
||||
|
||||
class Document(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
file: Optional[FileTypes] = None,
|
||||
custom_separator: Optional[list[str]] = None,
|
||||
upload_detail: Optional[list[UploadDetail]] = None,
|
||||
purpose: Literal["retrieval"],
|
||||
knowledge_id: Optional[str] = None,
|
||||
sentence_size: Optional[int] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> DocumentObject:
|
||||
if not file and not upload_detail:
|
||||
raise ValueError("At least one of `file` and `upload_detail` must be provided.")
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"file": file,
|
||||
"upload_detail": upload_detail,
|
||||
"purpose": purpose,
|
||||
"custom_separator": custom_separator,
|
||||
"knowledge_id": knowledge_id,
|
||||
"sentence_size": sentence_size,
|
||||
}
|
||||
)
|
||||
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
|
||||
if files:
|
||||
# It should be noted that the actual Content-Type header that will be
|
||||
# sent to the server will contain a `boundary` parameter, e.g.
|
||||
# multipart/form-data; boundary=---abc--
|
||||
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
|
||||
return self._post(
|
||||
"/files",
|
||||
body=maybe_transform(body, file_create_params.FileCreateParams),
|
||||
files=files,
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=DocumentObject,
|
||||
)
|
||||
|
||||
def edit(
|
||||
self,
|
||||
document_id: str,
|
||||
knowledge_type: str,
|
||||
*,
|
||||
custom_separator: Optional[list[str]] = None,
|
||||
sentence_size: Optional[int] = None,
|
||||
callback_url: Optional[str] = None,
|
||||
callback_header: Optional[dict[str, str]] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
|
||||
Args:
|
||||
document_id: 知识id
|
||||
knowledge_type: 知识类型:
|
||||
1:文章知识: 支持pdf,url,docx
|
||||
2.问答知识-文档: 支持pdf,url,docx
|
||||
3.问答知识-表格: 支持xlsx
|
||||
4.商品库-表格: 支持xlsx
|
||||
5.自定义: 支持pdf,url,docx
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
:param knowledge_type:
|
||||
:param document_id:
|
||||
:param timeout:
|
||||
:param extra_body:
|
||||
:param callback_header:
|
||||
:param sentence_size:
|
||||
:param extra_headers:
|
||||
:param callback_url:
|
||||
:param custom_separator:
|
||||
"""
|
||||
if not document_id:
|
||||
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
|
||||
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"id": document_id,
|
||||
"knowledge_type": knowledge_type,
|
||||
"custom_separator": custom_separator,
|
||||
"sentence_size": sentence_size,
|
||||
"callback_url": callback_url,
|
||||
"callback_header": callback_header,
|
||||
}
|
||||
)
|
||||
|
||||
return self._put(
|
||||
f"/document/{document_id}",
|
||||
body=maybe_transform(body, document_edit_params.DocumentEditParams),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=httpx.Response,
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
knowledge_id: str,
|
||||
*,
|
||||
purpose: str | NotGiven = NOT_GIVEN,
|
||||
page: str | NotGiven = NOT_GIVEN,
|
||||
limit: str | NotGiven = NOT_GIVEN,
|
||||
order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> DocumentPage:
|
||||
return self._get(
|
||||
"/files",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query=maybe_transform(
|
||||
{
|
||||
"knowledge_id": knowledge_id,
|
||||
"purpose": purpose,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"order": order,
|
||||
},
|
||||
document_list_params.DocumentListParams,
|
||||
),
|
||||
),
|
||||
cast_type=DocumentPage,
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
document_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Delete a file.
|
||||
|
||||
Args:
|
||||
|
||||
document_id: 知识id
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not document_id:
|
||||
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
|
||||
|
||||
return self._delete(
|
||||
f"/document/{document_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=httpx.Response,
|
||||
)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
document_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> DocumentData:
|
||||
"""
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not document_id:
|
||||
raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")
|
||||
|
||||
return self._get(
|
||||
f"/document/{document_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=DocumentData,
|
||||
)
|
||||
@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
cached_property,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params
|
||||
from ...types.knowledge.knowledge_list_resp import KnowledgePage
|
||||
from .document import Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
__all__ = ["Knowledge"]
|
||||
|
||||
|
||||
class Knowledge(BaseAPI):
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
super().__init__(client)
|
||||
|
||||
@cached_property
|
||||
def document(self) -> Document:
|
||||
return Document(self._client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
embedding_id: int,
|
||||
name: str,
|
||||
*,
|
||||
customer_identifier: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
|
||||
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
|
||||
bucket_id: Optional[str] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> KnowledgeInfo:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"embedding_id": embedding_id,
|
||||
"name": name,
|
||||
"customer_identifier": customer_identifier,
|
||||
"description": description,
|
||||
"background": background,
|
||||
"icon": icon,
|
||||
"bucket_id": bucket_id,
|
||||
}
|
||||
)
|
||||
return self._post(
|
||||
"/knowledge",
|
||||
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=KnowledgeInfo,
|
||||
)
|
||||
|
||||
def modify(
|
||||
self,
|
||||
knowledge_id: str,
|
||||
embedding_id: int,
|
||||
*,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
|
||||
icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> httpx.Response:
|
||||
body = deepcopy_minimal(
|
||||
{
|
||||
"id": knowledge_id,
|
||||
"embedding_id": embedding_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"background": background,
|
||||
"icon": icon,
|
||||
}
|
||||
)
|
||||
return self._put(
|
||||
f"/knowledge/{knowledge_id}",
|
||||
body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=httpx.Response,
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
*,
|
||||
page: int | NotGiven = 1,
|
||||
size: int | NotGiven = 10,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> KnowledgePage:
|
||||
return self._get(
|
||||
"/knowledge",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
query=maybe_transform(
|
||||
{
|
||||
"page": page,
|
||||
"size": size,
|
||||
},
|
||||
knowledge_list_params.KnowledgeListParams,
|
||||
),
|
||||
),
|
||||
cast_type=KnowledgePage,
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
knowledge_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Delete a file.
|
||||
|
||||
Args:
|
||||
knowledge_id: 知识库ID
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not knowledge_id:
|
||||
raise ValueError("Expected a non-empty value for `knowledge_id`")
|
||||
|
||||
return self._delete(
|
||||
f"/knowledge/{knowledge_id}",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=httpx.Response,
|
||||
)
|
||||
|
||||
def used(
|
||||
self,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> KnowledgeUsed:
|
||||
"""
|
||||
Returns the contents of the specified file.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
return self._get(
|
||||
"/knowledge/capacity",
|
||||
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
|
||||
cast_type=KnowledgeUsed,
|
||||
)
|
||||
@ -0,0 +1,3 @@
|
||||
from .tools import Tools
|
||||
|
||||
__all__ = ["Tools"]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user