Merge branch 'main' into feat/end-user-oauth

This commit is contained in:
zhsama
2025-12-05 02:58:50 +08:00
13 changed files with 759 additions and 112 deletions

View File

@ -2,6 +2,7 @@ from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from core.file import FileTransferMethod, FileType, FileUploadConfig
@ -98,6 +99,7 @@ class VariableEntityType(StrEnum):
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
@ -118,6 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@ -129,6 +132,17 @@ class VariableEntity(BaseModel):
def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
class RagPipelineVariableEntity(VariableEntity):
"""

View File

@ -156,78 +156,82 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query
if not conversation:
conversation = Conversation(
with db.session.begin():
if not conversation:
conversation = Conversation(
app_id=app_config.app_id,
app_model_config_id=app_model_config_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name=conversation_name,
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
)
db.session.add(conversation)
db.session.flush()
db.session.refresh(conversation)
else:
conversation.updated_at = naive_utc_now()
message = Message(
app_id=app_config.app_id,
app_model_config_id=app_model_config_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name=conversation_name,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status="normal",
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(conversation)
db.session.add(message)
db.session.flush()
db.session.refresh(message)
message_files = []
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
message_files.append(message_file)
if message_files:
db.session.add_all(message_files)
db.session.commit()
db.session.refresh(conversation)
else:
conversation.updated_at = naive_utc_now()
db.session.commit()
message = Message(
app_id=app_config.app_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(message)
db.session.commit()
db.session.refresh(message)
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)
db.session.commit()
return conversation, message
def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:

View File

@ -1,3 +1,8 @@
from typing import Any
from jsonschema import Draft7Validator, ValidationError
from core.app.app_config.entities import VariableEntityType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -15,6 +20,7 @@ class StartNode(Node[StartNodeData]):
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
self._validate_and_normalize_json_object_inputs(node_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
@ -24,3 +30,27 @@ class StartNode(Node[StartNodeData]):
outputs = dict(node_inputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None:
for variable in self.node_data.variables:
if variable.type != VariableEntityType.JSON_OBJECT:
continue
key = variable.variable
value = node_inputs.get(key)
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
if not isinstance(value, dict):
raise ValueError(f"{key} must be a JSON object")
schema = variable.json_schema
if not schema:
continue
try:
Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = value