diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py
index 30f37878c571..07f3f320c5e4 100644
--- a/tests/trace_server/test_clickhouse_trace_server_migrator.py
+++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py
@@ -387,6 +387,20 @@ def test_create_distributed_table_sql():
     assert sql.strip() == expected.strip()
 
 
+def test_create_distributed_table_sql_id_sharded():
+    """Test distributed table creation SQL for ID-sharded tables."""
+    distributed_migrator = DistributedClickHouseTraceServerMigrator(
+        Mock(), replicated_cluster="test_cluster", migration_dir=DEFAULT_MIGRATION_DIR
+    )
+    sql = distributed_migrator._create_distributed_table_sql("calls_complete")
+    expected = """
+        CREATE TABLE IF NOT EXISTS calls_complete ON CLUSTER test_cluster
+        AS calls_complete_local
+        ENGINE = Distributed(test_cluster, currentDatabase(), calls_complete_local, sipHash64(id))
+    """
+    assert sql.strip() == expected.strip()
+
+
 def test_format_distributed_sql():
     """Test distributed SQL formatting for CREATE TABLE and other DDL."""
     distributed_migrator = DistributedClickHouseTraceServerMigrator(
diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py
index 750492db2d7b..89894d39bc3a 100644
--- a/weave/trace_server/clickhouse_trace_server_migrator.py
+++ b/weave/trace_server/clickhouse_trace_server_migrator.py
@@ -85,6 +85,12 @@
 # Constants for table naming conventions
 VIEW_SUFFIX = "_view"
 
+# Tables that use ID-based sharding (sipHash64(field)) instead of random sharding
+# in distributed mode. Maps table name to the field used for sharding.
+# This ensures all data for a specific ID goes to the same shard, enabling
+# efficient point lookups.
+ID_SHARDED_TABLES: dict[str, str] = {"calls_complete": "id"}
+
 
 @dataclass(frozen=True)
 class PostMigrationHookContext:
@@ -510,28 +516,36 @@ def _add_on_cluster_clause(self, sql_query: str) -> str:
         # ALTER TABLE
         if SQLPatterns.ALTER_TABLE_STMT.search(sql_query):
             return SQLPatterns.ALTER_TABLE_NAME_PATTERN.sub(
-                lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}",
+                lambda m: (
+                    f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}"
+                ),
                 sql_query,
             )
 
         # CREATE TABLE
         if SQLPatterns.CREATE_TABLE_STMT.search(sql_query):
             return SQLPatterns.CREATE_TABLE_NAME_PATTERN.sub(
-                lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}",
+                lambda m: (
+                    f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}"
+                ),
                 sql_query,
             )
 
         # DROP VIEW
         if SQLPatterns.DROP_VIEW_STMT.search(sql_query):
             return SQLPatterns.DROP_VIEW_NAME_PATTERN.sub(
-                lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}",
+                lambda m: (
+                    f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}"
+                ),
                 sql_query,
             )
 
         # CREATE VIEW / CREATE MATERIALIZED VIEW
         if SQLPatterns.CREATE_VIEW_STMT.search(sql_query):
             return SQLPatterns.CREATE_VIEW_NAME_PATTERN.sub(
-                lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}",
+                lambda m: (
+                    f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}"
+                ),
                 sql_query,
             )
 
@@ -781,12 +795,21 @@ def _format_distributed_sql(self, sql_query: str) -> DistributedTransformResult:
         )
 
     def _create_distributed_table_sql(self, table_name: str) -> str:
-        """Generate SQL to create a distributed table."""
+        """Generate SQL to create a distributed table.
+
+        For tables in ID_SHARDED_TABLES, uses sipHash64(field) as the sharding key
+        to ensure all data for a specific ID goes to the same shard, enabling
+        efficient point lookups. Other tables use rand() for even distribution.
+        """
         local_table_name = table_name + ch_settings.LOCAL_TABLE_SUFFIX
+        if shard_field := ID_SHARDED_TABLES.get(table_name):
+            sharding_key = f"sipHash64({shard_field})"
+        else:
+            sharding_key = "rand()"
         return f"""
         CREATE TABLE IF NOT EXISTS {table_name} ON CLUSTER {self.replicated_cluster}
         AS {local_table_name}
-        ENGINE = Distributed({self.replicated_cluster}, currentDatabase(), {local_table_name}, rand())
+        ENGINE = Distributed({self.replicated_cluster}, currentDatabase(), {local_table_name}, {sharding_key})
         """
 
     @staticmethod