Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions api/models/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,19 +1026,21 @@ def get_embedding(self) -> list[float]:
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301


class DatasetCollectionBinding(Base):
class DatasetCollectionBinding(TypeBase):
__tablename__ = "dataset_collection_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
)

id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
collection_name = mapped_column(String(64), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)


class TidbAuthBinding(Base):
Expand Down Expand Up @@ -1176,7 +1178,7 @@ class ExternalKnowledgeBindings(TypeBase):
)


class DatasetAutoDisableLog(Base):
class DatasetAutoDisableLog(TypeBase):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
Expand All @@ -1185,12 +1187,14 @@ class DatasetAutoDisableLog(Base):
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
)

id = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)


class RateLimitLog(TypeBase):
Expand Down
78 changes: 34 additions & 44 deletions api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.file import helpers as file_helpers
from core.tools.signature import sign_tool_file
from core.workflow.enums import WorkflowExecutionStatus
Expand Down Expand Up @@ -594,7 +594,7 @@
return tenant


class OAuthProviderApp(Base):
class OAuthProviderApp(TypeBase):
"""
Globally shared OAuth provider app information.
Only for Dify Cloud.
Expand All @@ -606,18 +606,21 @@
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
)

id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
app_icon = mapped_column(String(255), nullable=False)
app_label = mapped_column(sa.JSON, nullable=False, default="{}")
client_id = mapped_column(String(255), nullable=False)
client_secret = mapped_column(String(255), nullable=False)
redirect_uris = mapped_column(sa.JSON, nullable=False, default="[]")
scope = mapped_column(
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)
redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list)
scope: Mapped[str] = mapped_column(
String(255),
nullable=False,
server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"),
default="read:name read:email read:avatar read:interface_language read:timezone",
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


class Conversation(Base):
Expand Down Expand Up @@ -1335,36 +1338,15 @@
}


class MessageFile(Base):
class MessageFile(TypeBase):
__tablename__ = "message_files"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_file_pkey"),
sa.Index("message_file_message_idx", "message_id"),
sa.Index("message_file_created_by_idx", "created_by"),
)

def __init__(
self,
*,
message_id: str,
type: FileType,
transfer_method: FileTransferMethod,
url: str | None = None,
belongs_to: Literal["user", "assistant"] | None = None,
upload_file_id: str | None = None,
created_by_role: CreatorUserRole,
created_by: str,
):
self.message_id = message_id
self.type = type
self.transfer_method = transfer_method
self.url = url
self.belongs_to = belongs_to
self.upload_file_id = upload_file_id
self.created_by_role = created_by_role.value
self.created_by = created_by

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
Expand All @@ -1373,7 +1355,9 @@
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)


class MessageAnnotation(Base):
Expand Down Expand Up @@ -1447,22 +1431,28 @@
return account


class AppAnnotationSetting(Base):
class AppAnnotationSetting(TypeBase):
__tablename__ = "app_annotation_settings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
sa.Index("app_annotation_settings_app_idx", "app_id"),
)

id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=False)
score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0"))
collection_binding_id = mapped_column(StringUUID, nullable=False)
created_user_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_user_id = mapped_column(StringUUID, nullable=False)
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
score_threshold: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"), default=0.0)
collection_binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

Check failure on line 1444 in api/models/model.py

View workflow job for this annotation

GitHub Actions / Style Check / Python Style

Fields without default values cannot appear after fields with default values (reportGeneralTypeIssues)
created_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

Check failure on line 1445 in api/models/model.py

View workflow job for this annotation

GitHub Actions / Style Check / Python Style

Fields without default values cannot appear after fields with default values (reportGeneralTypeIssues)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

Check failure on line 1449 in api/models/model.py

View workflow job for this annotation

GitHub Actions / Style Check / Python Style

Fields without default values cannot appear after fields with default values (reportGeneralTypeIssues)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)

@property
Expand Down
40 changes: 23 additions & 17 deletions api/models/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from libs.uuid_utils import uuidv7

from .base import Base, TypeBase
from .base import TypeBase
from .engine import db
from .types import LongText, StringUUID

Expand Down Expand Up @@ -262,7 +262,7 @@ class ProviderModelSetting(TypeBase):
)


class LoadBalancingModelConfig(Base):
class LoadBalancingModelConfig(TypeBase):
"""
Configurations for load balancing models.
"""
Expand All @@ -273,23 +273,25 @@ class LoadBalancingModelConfig(Base):
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
)


class ProviderCredential(Base):
class ProviderCredential(TypeBase):
"""
Provider credential - stores multiple named credentials for each provider
"""
Expand All @@ -300,18 +302,20 @@ class ProviderCredential(Base):
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
)


class ProviderModelCredential(Base):
class ProviderModelCredential(TypeBase):
"""
Provider model credential - stores multiple named credentials for each provider model
"""
Expand All @@ -328,14 +332,16 @@ class ProviderModelCredential(Base):
),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
)
9 changes: 6 additions & 3 deletions api/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,27 +129,30 @@


# tenant level trigger oauth client params (client_id, client_secret, etc.)
class TriggerOAuthTenantClient(Base):
class TriggerOAuthTenantClient(TypeBase):
__tablename__ = "trigger_oauth_tenant_clients"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trigger_oauth_tenant_client_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
# oauth params of the trigger provider
encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False)

Check failure on line 146 in api/models/trigger.py

View workflow job for this annotation

GitHub Actions / Style Check / Python Style

Fields without default values cannot appear after fields with default values (reportGeneralTypeIssues)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
server_onupdate=func.current_timestamp(),
init=False,
)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -852,12 +852,14 @@ def test_get_agent_logs_with_files(self, db_session_with_containers, mock_extern
# Add files to message
from models.model import MessageFile

assert message.from_account_id is not None
message_file1 = MessageFile(
message_id=message.id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
url="http://example.com/file1.jpg",
belongs_to="user",
upload_file_id=str(fake.uuid4()),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
Expand All @@ -867,6 +869,7 @@ def test_get_agent_logs_with_files(self, db_session_with_containers, mock_extern
transfer_method=FileTransferMethod.REMOTE_URL,
url="http://example.com/file2.png",
belongs_to="user",
upload_file_id=str(fake.uuid4()),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
Expand Down
Loading
Loading