Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/trace_server/test_clickhouse_trace_server_migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,20 @@ def test_create_distributed_table_sql():
assert sql.strip() == expected.strip()


def test_create_distributed_table_sql_id_sharded():
    """Verify the CREATE TABLE DDL produced for a table configured for ID sharding.

    `calls_complete` is registered in ID_SHARDED_TABLES, so the generated
    Distributed engine clause must use sipHash64(id) rather than rand().
    """
    migrator = DistributedClickHouseTraceServerMigrator(
        Mock(), replicated_cluster="test_cluster", migration_dir=DEFAULT_MIGRATION_DIR
    )
    generated = migrator._create_distributed_table_sql("calls_complete")
    expected_ddl = """
    CREATE TABLE IF NOT EXISTS calls_complete ON CLUSTER test_cluster
    AS calls_complete_local
    ENGINE = Distributed(test_cluster, currentDatabase(), calls_complete_local, sipHash64(id))
    """
    assert generated.strip() == expected_ddl.strip()


def test_format_distributed_sql():
"""Test distributed SQL formatting for CREATE TABLE and other DDL."""
distributed_migrator = DistributedClickHouseTraceServerMigrator(
Expand Down
35 changes: 29 additions & 6 deletions weave/trace_server/clickhouse_trace_server_migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@
# Constants for table naming conventions
VIEW_SUFFIX = "_view"

# Tables that use ID-based sharding (sipHash64(field)) instead of random sharding
# in distributed mode. Maps table name -> the column used as the sharding key.
# Routing every row with the same ID to the same shard enables efficient
# single-shard point lookups for that ID.
ID_SHARDED_TABLES: dict[str, str] = {"calls_complete": "id"}


@dataclass(frozen=True)
class PostMigrationHookContext:
Expand Down Expand Up @@ -510,28 +516,36 @@ def _add_on_cluster_clause(self, sql_query: str) -> str:
# ALTER TABLE
if SQLPatterns.ALTER_TABLE_STMT.search(sql_query):
return SQLPatterns.ALTER_TABLE_NAME_PATTERN.sub(
lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}",
lambda m: (
[inline review comment by author: lint]
f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}"
),
sql_query,
)

# CREATE TABLE
if SQLPatterns.CREATE_TABLE_STMT.search(sql_query):
return SQLPatterns.CREATE_TABLE_NAME_PATTERN.sub(
lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}",
lambda m: (
f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}"
),
sql_query,
)

# DROP VIEW
if SQLPatterns.DROP_VIEW_STMT.search(sql_query):
return SQLPatterns.DROP_VIEW_NAME_PATTERN.sub(
lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}",
lambda m: (
f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}"
),
sql_query,
)

# CREATE VIEW / CREATE MATERIALIZED VIEW
if SQLPatterns.CREATE_VIEW_STMT.search(sql_query):
return SQLPatterns.CREATE_VIEW_NAME_PATTERN.sub(
lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}",
lambda m: (
f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}"
),
sql_query,
)

Expand Down Expand Up @@ -781,12 +795,21 @@ def _format_distributed_sql(self, sql_query: str) -> DistributedTransformResult:
)

def _create_distributed_table_sql(self, table_name: str) -> str:
    """Generate SQL to create a distributed table.

    For tables in ID_SHARDED_TABLES, uses sipHash64(field) as the sharding key
    so that all rows for a given ID land on the same shard, enabling efficient
    point lookups. All other tables use rand() for even distribution.

    Args:
        table_name: Logical (distributed) table name; the backing local table
            is derived by appending ch_settings.LOCAL_TABLE_SUFFIX.

    Returns:
        A CREATE TABLE IF NOT EXISTS ... ON CLUSTER statement whose engine is
        Distributed(cluster, currentDatabase(), <local table>, <sharding key>).
    """
    local_table_name = table_name + ch_settings.LOCAL_TABLE_SUFFIX
    # ID-sharded tables hash a designated column; everything else shards randomly.
    if shard_field := ID_SHARDED_TABLES.get(table_name):
        sharding_key = f"sipHash64({shard_field})"
    else:
        sharding_key = "rand()"
    return f"""
    CREATE TABLE IF NOT EXISTS {table_name} ON CLUSTER {self.replicated_cluster}
    AS {local_table_name}
    ENGINE = Distributed({self.replicated_cluster}, currentDatabase(), {local_table_name}, {sharding_key})
    """

@staticmethod
Expand Down