mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
Feat: conversation variable & variable assigner node (#7222)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@ -1,15 +1,19 @@
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import CHAR, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from .model import AppMode
|
||||
from .types import StringUUID
|
||||
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
|
||||
|
||||
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
|
||||
|
||||
|
||||
class CreatedByRole(Enum):
|
||||
"""
|
||||
Enum class for createdByRole
|
||||
"""
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
ACCOUNT = 'account'
|
||||
END_USER = 'end_user'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'CreatedByRole':
|
||||
@ -23,49 +27,3 @@ class CreatedByRole(Enum):
|
||||
if role.value == value:
|
||||
return role
|
||||
raise ValueError(f'invalid createdByRole value {value}')
|
||||
|
||||
|
||||
class CreatedFrom(Enum):
|
||||
"""
|
||||
Enum class for createdFrom
|
||||
"""
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
EXPLORE = "explore"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'CreatedFrom':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for role in cls:
|
||||
if role.value == value:
|
||||
return role
|
||||
raise ValueError(f'invalid createdFrom value {value}')
|
||||
|
||||
|
||||
class StringUUID(TypeDecorator):
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == 'postgresql':
|
||||
return str(value)
|
||||
else:
|
||||
return value.hex
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == 'postgresql':
|
||||
return dialect.type_descriptor(UUID())
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(36))
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
@ -4,7 +4,8 @@ import json
|
||||
from flask_login import UserMixin
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class AccountStatus(str, enum.Enum):
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import enum
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class APIBasedExtensionPoint(enum.Enum):
|
||||
|
||||
@ -16,9 +16,10 @@ from configs import dify_config
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import StringUUID
|
||||
from models.account import Account
|
||||
from models.model import App, Tag, TagBinding, UploadFile
|
||||
|
||||
from .account import Account
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class Dataset(db.Model):
|
||||
|
||||
@ -14,8 +14,8 @@ from core.file.upload_file_parser import UploadFileParser
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import generate_string
|
||||
|
||||
from . import StringUUID
|
||||
from .account import Account, Tenant
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DifySetup(db.Model):
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
|
||||
@ -3,7 +3,8 @@ import json
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DataSourceOauthBinding(db.Model):
|
||||
|
||||
@ -2,7 +2,8 @@ import json
|
||||
from enum import Enum
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class ToolProviderName(Enum):
|
||||
|
||||
@ -6,8 +6,9 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
from models.model import Account, App, Tenant
|
||||
|
||||
from .model import Account, App, Tenant
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class BuiltinToolProvider(db.Model):
|
||||
|
||||
26
api/models/types.py
Normal file
26
api/models/types.py
Normal file
@ -0,0 +1,26 @@
|
||||
from sqlalchemy import CHAR, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
class StringUUID(TypeDecorator):
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == 'postgresql':
|
||||
return str(value)
|
||||
else:
|
||||
return value.hex
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == 'postgresql':
|
||||
return dialect.type_descriptor(UUID())
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(36))
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
@ -1,7 +1,8 @@
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
from models.model import Message
|
||||
|
||||
from .model import Message
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class SavedMessage(db.Model):
|
||||
|
||||
@ -3,18 +3,18 @@ from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
import contexts
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.app.segments import (
|
||||
SecretVariable,
|
||||
Variable,
|
||||
factory,
|
||||
)
|
||||
from core.app.segments import SecretVariable, Variable, factory
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models import StringUUID
|
||||
from models.account import Account
|
||||
|
||||
from .account import Account
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class CreatedByRole(Enum):
|
||||
@ -122,6 +122,7 @@ class Workflow(db.Model):
|
||||
updated_by = db.Column(StringUUID)
|
||||
updated_at = db.Column(db.DateTime)
|
||||
_environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')
|
||||
_conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}')
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
@ -249,9 +250,27 @@ class Workflow(db.Model):
|
||||
'graph': self.graph_dict,
|
||||
'features': self.features_dict,
|
||||
'environment_variables': [var.model_dump(mode='json') for var in environment_variables],
|
||||
'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables],
|
||||
}
|
||||
return result
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[Variable]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = '{}'
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@conversation_variables.setter
|
||||
def conversation_variables(self, value: Sequence[Variable]) -> None:
|
||||
self._conversation_variables = json.dumps(
|
||||
{var.name: var.model_dump() for var in value},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(Enum):
|
||||
"""
|
||||
@ -702,3 +721,34 @@ class WorkflowAppLog(db.Model):
|
||||
created_by_role = CreatedByRole.value_of(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) \
|
||||
if created_by_role == CreatedByRole.END_USER else None
|
||||
|
||||
|
||||
class ConversationVariable(db.Model):
|
||||
__tablename__ = 'workflow__conversation_variables'
|
||||
|
||||
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
|
||||
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
|
||||
app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
|
||||
data = db.Column(db.Text, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp())
|
||||
|
||||
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
|
||||
self.id = id
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.data = data
|
||||
|
||||
@classmethod
|
||||
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable':
|
||||
obj = cls(
|
||||
id=variable.id,
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
data=variable.model_dump_json(),
|
||||
)
|
||||
return obj
|
||||
|
||||
def to_variable(self) -> Variable:
|
||||
mapping = json.loads(self.data)
|
||||
return factory.build_variable_from_mapping(mapping)
|
||||
|
||||
Reference in New Issue
Block a user