Feat: conversation variable & variable assigner node (#7222)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
KVOJJJin
2024-08-13 14:44:10 +08:00
committed by GitHub
parent 8b55bd5828
commit 935e72d449
128 changed files with 3354 additions and 683 deletions

View File

@ -1,15 +1,19 @@
from enum import Enum
from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
from .model import AppMode
from .types import StringUUID
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
class CreatedByRole(Enum):
"""
Enum class for createdByRole
"""
ACCOUNT = "account"
END_USER = "end_user"
ACCOUNT = 'account'
END_USER = 'end_user'
@classmethod
def value_of(cls, value: str) -> 'CreatedByRole':
@ -23,49 +27,3 @@ class CreatedByRole(Enum):
if role.value == value:
return role
raise ValueError(f'invalid createdByRole value {value}')
class CreatedFrom(Enum):
"""
Enum class for createdFrom
"""
SERVICE_API = "service-api"
WEB_APP = "web-app"
EXPLORE = "explore"
@classmethod
def value_of(cls, value: str) -> 'CreatedFrom':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for role in cls:
if role.value == value:
return role
raise ValueError(f'invalid createdFrom value {value}')
class StringUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
return str(value)
else:
return value.hex
def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
if value is None:
return value
return str(value)

View File

@ -4,7 +4,8 @@ import json
from flask_login import UserMixin
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class AccountStatus(str, enum.Enum):

View File

@ -1,7 +1,8 @@
import enum
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class APIBasedExtensionPoint(enum.Enum):

View File

@ -16,9 +16,10 @@ from configs import dify_config
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile
from .account import Account
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
class Dataset(db.Model):

View File

@ -14,8 +14,8 @@ from core.file.upload_file_parser import UploadFileParser
from extensions.ext_database import db
from libs.helper import generate_string
from . import StringUUID
from .account import Account, Tenant
from .types import StringUUID
class DifySetup(db.Model):

View File

@ -1,7 +1,8 @@
from enum import Enum
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class ProviderType(Enum):

View File

@ -3,7 +3,8 @@ import json
from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class DataSourceOauthBinding(db.Model):

View File

@ -2,7 +2,8 @@ import json
from enum import Enum
from extensions.ext_database import db
from models import StringUUID
from .types import StringUUID
class ToolProviderName(Enum):

View File

@ -6,8 +6,9 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db
from models import StringUUID
from models.model import Account, App, Tenant
from .model import Account, App, Tenant
from .types import StringUUID
class BuiltinToolProvider(db.Model):

26
api/models/types.py Normal file
View File

@ -0,0 +1,26 @@
from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
class StringUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
return str(value)
else:
return value.hex
def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
if value is None:
return value
return str(value)

View File

@ -1,7 +1,8 @@
from extensions.ext_database import db
from models import StringUUID
from models.model import Message
from .model import Message
from .types import StringUUID
class SavedMessage(db.Model):

View File

@ -3,18 +3,18 @@ from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional, Union
from sqlalchemy import func
from sqlalchemy.orm import Mapped
import contexts
from constants import HIDDEN_VALUE
from core.app.segments import (
SecretVariable,
Variable,
factory,
)
from core.app.segments import SecretVariable, Variable, factory
from core.helper import encrypter
from extensions.ext_database import db
from libs import helper
from models import StringUUID
from models.account import Account
from .account import Account
from .types import StringUUID
class CreatedByRole(Enum):
@ -122,6 +122,7 @@ class Workflow(db.Model):
updated_by = db.Column(StringUUID)
updated_at = db.Column(db.DateTime)
_environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')
_conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}')
@property
def created_by_account(self):
@ -249,9 +250,27 @@ class Workflow(db.Model):
'graph': self.graph_dict,
'features': self.features_dict,
'environment_variables': [var.model_dump(mode='json') for var in environment_variables],
'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables],
}
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = '{}'
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()]
return results
@conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]) -> None:
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
)
class WorkflowRunTriggeredFrom(Enum):
"""
@ -702,3 +721,34 @@ class WorkflowAppLog(db.Model):
created_by_role = CreatedByRole.value_of(self.created_by_role)
return db.session.get(EndUser, self.created_by) \
if created_by_role == CreatedByRole.END_USER else None
class ConversationVariable(db.Model):
__tablename__ = 'workflow__conversation_variables'
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
data = db.Column(db.Text, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp())
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
self.id = id
self.app_id = app_id
self.conversation_id = conversation_id
self.data = data
@classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable':
obj = cls(
id=variable.id,
app_id=app_id,
conversation_id=conversation_id,
data=variable.model_dump_json(),
)
return obj
def to_variable(self) -> Variable:
mapping = json.loads(self.data)
return factory.build_variable_from_mapping(mapping)