Skip to content

Commit ae93ffe

Browse files
tcya authored and
meterstick-copybara committed
Allow changing the SQL dialect used for query generation in compute_on_beam, and add support for the Calcite dialect. Currently, Meterstick's Beam integration uses GoogleSQL as the default SQL dialect. Users can now specify dialect='Calcite' to use Calcite, which is the default dialect of Beam SQL.
PiperOrigin-RevId: 785632281
1 parent 1be1eb7 commit ae93ffe

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ the top of `sql.py` file.
351351
There is also a
352352

353353
```python
354-
compute_on_beam(pcol, split_by=None, execute=None, melted=False, mode=None)
354+
compute_on_beam(pcol, split_by=None, execute=None, melted=False, mode=None, dialect=None)
355355
```
356356

357357
method which takes a [`PCollection`](https://beam.apache.org/documentation/programming-guide/#pcollections)
@@ -365,7 +365,8 @@ As a result,
365365
the [InteractiveRunner](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/runners/interactive/README.md)
366366
does [NOT](https://issues.apache.org/jira/browse/BEAM-10708).
367367
- The config of the pipeline that carries the `PCollection` is set up by you.
368-
For example, your setup decides if the pipeline will be ran in process or in Cloud.
368+
For example, your setup decides if the pipeline will be run in process or in
369+
Cloud.
369370

370371
## Custom Metric
371372

metrics.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def compute_on_beam(
7474
cache_key=None,
7575
cache=None,
7676
sql_transform_kwargs=None,
77+
dialect=None,
7778
**kwargs,
7879
):
7980
"""A wrapper for metric.compute_on_beam()."""
@@ -86,6 +87,7 @@ def compute_on_beam(
8687
cache_key,
8788
cache=cache,
8889
sql_transform_kwargs=sql_transform_kwargs,
90+
dialect=dialect,
8991
**kwargs,
9092
)
9193

@@ -841,6 +843,7 @@ def compute_on_beam(
841843
cache_key=None,
842844
cache=None,
843845
sql_transform_kwargs=None,
846+
dialect=None,
844847
**kwargs,
845848
):
846849
"""Computes on an Apache Beam PCollection input.
@@ -862,6 +865,9 @@ def compute_on_beam(
862865
sql_transform_kwargs: A dict that holds the kwargs to be passed to
863866
SqlTransform defined in
864867
https://beam.apache.org/releases/pydoc/2.30.0/apache_beam.transforms.sql.html.
868+
dialect: The dialect of the SQL query. If not specified, it will be
869+
the current DIALECT variable in sql.py. The DIALECT variable will be
870+
changed during the computation of this metric and restored after that.
865871
**kwargs: Other kwargs passed to compute_on_sql.
866872
867873
Returns:
@@ -887,7 +893,9 @@ def e(q):
887893

888894
# pylint: disable=g-import-not-at-top
889895

896+
current_dialect = sql.DIALECT
890897
try:
898+
sql.set_dialect(dialect)
891899
return self.compute_on_sql(
892900
'PCOLLECTION', split_by, e, melted, mode, cache_key, cache, **kwargs
893901
)
@@ -899,6 +907,8 @@ def e(q):
899907
"compute_on_beam(..., mode='mixed')."
900908
) from e
901909
raise
910+
finally:
911+
sql.set_dialect(current_dialect)
902912

903913
def compute_equivalent(self, df, split_by=None):
904914
equiv, df = utils.get_fully_expanded_equivalent_metric_tree(self, df)

sql.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def array_agg_fn_not_implemented(
171171
raise NotImplementedError('ARRAY_AGG is not implemented.')
172172

173173

174-
def array_index_fn_googlesql(array: str, zero_based_idx: int):
174+
def array_index_safe_offset_fn(array: str, zero_based_idx: int):
175175
return f'{array}[SAFE_OFFSET({zero_based_idx})]'
176176

177177

@@ -371,6 +371,7 @@ def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
371371
'MariaDB': drop_temp_table_if_exists_then_create_temp_table,
372372
'Oracle': create_temp_table_fn_not_implemented,
373373
'SQL Server': 'SELECT * INTO #{alias} FROM ({query});'.format,
374+
'Calcite': 'CREATE OR REPLACE TEMPORARY TABLE {alias} AS {query};'.format,
374375
}
375376
SUPPORT_FULL_JOIN_OPTIONS = {
376377
'Default': True,
@@ -391,6 +392,7 @@ def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
391392
'Default': lambda columns: ', '.join(columns.aliases),
392393
'SQL Server': lambda columns: ', '.join(columns.expressions),
393394
'Trino': lambda columns: ', '.join(map(str, range(1, len(columns) + 1))),
395+
'Calcite': lambda columns: ', '.join(columns.expressions),
394396
}
395397
SAFE_DIVIDE_OPTIONS = {
396398
'Default': safe_divide_fn_default,
@@ -405,6 +407,7 @@ def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
405407
'Oracle': 'DBMS_RANDOM.VALUE'.format,
406408
'SQL Server': sql_server_rand_fn_not_implemented,
407409
'SQLite': '0.5 - RANDOM() / CAST(-9223372036854775808 AS REAL) / 2'.format,
410+
'Calcite': 'RAND()'.format,
408411
}
409412
# Manually evaluated run_only_once_in_with_clause for each dialect.
410413
NEED_TEMP_TABLE_OPTIONS = {
@@ -415,6 +418,7 @@ def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
415418
'SQL Server': True,
416419
'Trino': True,
417420
'SQLite': False,
421+
'Calcite': True,
418422
}
419423
CEIL_OPTIONS = {
420424
'Default': 'CEIL({})'.format,
@@ -426,6 +430,7 @@ def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
426430
'PostgreSQL': percentile_cont_fn,
427431
'Oracle': percentile_cont_fn,
428432
'Trino': approx_percentile_fn,
433+
'Calcite': percentile_cont_fn,
429434
}
430435
ARRAY_AGG_OPTIONS = {
431436
'Default': array_agg_fn_not_implemented,
@@ -436,14 +441,16 @@ def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
436441
# JSON_ARRAYAGG has been added in SQL Server 2025. Will update later.
437442
'SQL Server': array_agg_fn_not_implemented,
438443
'Trino': array_agg_fn_no_use_filter_no_limit,
444+
'Calcite': array_agg_fn_no_use_filter_no_limit,
439445
}
440446
ARRAY_INDEX_OPTIONS = {
441447
'Default': array_index_fn_not_implemented,
442-
'GoogleSQL': array_index_fn_googlesql,
448+
'GoogleSQL': array_index_safe_offset_fn,
443449
'PostgreSQL': array_subscript_fn,
444450
'MariaDB': json_extract_fn,
445451
'Oracle': json_value_fn,
446452
'Trino': element_at_index_fn,
453+
'Calcite': array_index_safe_offset_fn,
447454
}
448455
NTH_OPTIONS = {
449456
'Default': nth_fn_default,
@@ -463,6 +470,7 @@ def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
463470
'Oracle': 'TO_CHAR({})'.format,
464471
'SQL Server': 'CAST({} AS VARCHAR(MAX))'.format,
465472
'Trino': 'CAST({} AS VARCHAR)'.format,
473+
'Calcite': 'CAST({} AS VARCHAR)'.format,
466474
}
467475
UNIFORM_MAPPING_OPTIONS = {
468476
'Default': uniform_mapping_fn_not_implemented,
@@ -486,12 +494,14 @@ def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
486494
'MariaDB': unnest_json_array_fn,
487495
'Oracle': unnest_json_array_fn,
488496
'Trino': unnest_array_with_ordinality_fn,
497+
'Calcite': unnest_array_with_ordinality_fn,
489498
}
490499
UNNEST_ARRAY_LITERAL_OPTIONS = {
491500
'Default': unnest_array_literal_fn_not_implemented,
492501
'GoogleSQL': unnest_array_literal_fn_googlesql,
493502
'PostgreSQL': unnest_array_literal_fn_postgresql,
494503
'Trino': unnest_array_literal_fn_postgresql,
504+
'Calcite': unnest_array_literal_fn_postgresql,
495505
}
496506
GENERATE_ARRAY_OPTIONS = {
497507
'Default': generate_array_fn_not_implemented,
@@ -513,11 +523,13 @@ def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
513523
}
514524

515525

516-
def set_dialect(dialect: str):
526+
def set_dialect(dialect: Optional[str]):
517527
"""Sets the dialect of the SQL query."""
518528
# You can manually override the options below. You can manually test it in
519529
# https://colab.research.google.com/drive/1y3UigzEby1anMM3-vXocBx7V8LVblIAp?usp=sharing.
520530
global DIALECT, NEED_TEMP_TABLE, CREATE_TEMP_TABLE_FN, SUPPORT_FULL_JOIN, ROW_NUMBER_REQUIRE_ORDER_BY, GROUP_BY_FN, RAND_FN, CEIL_FN, SAFE_DIVIDE_FN, QUANTILE_FN, ARRAY_AGG_FN, ARRAY_INDEX_FN, NTH_VALUE_FN, COUNTIF_FN, STRING_CAST_FN, FLOAT_CAST_FN, UNIFORM_MAPPING_FN, UNNEST_ARRAY_FN, UNNEST_ARRAY_LITERAL_FN, GENERATE_ARRAY_FN, DUPLICATE_DATA_N_TIMES_FN
531+
if not dialect:
532+
return
521533
DIALECT = dialect
522534
NEED_TEMP_TABLE = _get_dialect_option(NEED_TEMP_TABLE_OPTIONS)
523535
CREATE_TEMP_TABLE_FN = _get_dialect_option(CREATE_TEMP_TABLE_OPTIONS)

0 commit comments

Comments
 (0)