Skip to content

Commit 6d7defc

Browse files
authored
feat(bigquery): add QueryJobConfig properties to bigquery backend specified at query time (#11255)
1 parent 9145757 commit 6d7defc

File tree

2 files changed

+183
-12
lines changed

2 files changed

+183
-12
lines changed

ibis/backends/bigquery/__init__.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import concurrent.futures
66
import contextlib
7+
import copy
78
import glob
89
import os
910
import re
@@ -765,17 +766,18 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
765766
)
766767
return BigQuerySchema.to_ibis(job.schema)
767768

768-
def raw_sql(self, query: str, params=None) -> RowIterator:
769-
query_parameters = [
770-
bigquery_param(param.type(), value, param.get_name())
771-
for param, value in (params or {}).items()
772-
]
769+
def raw_sql(
770+
self,
771+
query: str,
772+
params: Mapping[ir.Scalar, Any] | None = None,
773+
query_job_config: bq.QueryJobConfig | None = None,
774+
) -> RowIterator:
775+
job_config = _merge_params_into_config(query_job_config, params)
776+
773777
with contextlib.suppress(AttributeError):
774778
query = query.sql(self.dialect)
775779

776-
job_config = bq.job.QueryJobConfig(query_parameters=query_parameters or [])
777-
778-
return self._client_query(
780+
return self.client.query_and_wait(
779781
query, job_config=job_config, project=self.billing_project
780782
)
781783

@@ -867,13 +869,18 @@ def _to_query(
867869
*,
868870
params: Mapping[ir.Scalar, Any] | None = None,
869871
limit: int | str | None = None,
872+
query_job_config: bq.QueryJobConfig | None = None,
870873
**kwargs: Any,
871874
) -> RowIterator:
872875
self._run_pre_execute_hooks(table_expr)
873876
sql = self.compile(table_expr, limit=limit, params=params, **kwargs)
874877
self._log(sql)
875878

876-
return self.raw_sql(sql, params=params)
879+
return self.raw_sql(
880+
sql,
881+
params=params,
882+
query_job_config=query_job_config,
883+
)
877884

878885
def to_pyarrow(
879886
self,
@@ -882,14 +889,21 @@ def to_pyarrow(
882889
*,
883890
params: Mapping[ir.Scalar, Any] | None = None,
884891
limit: int | str | None = None,
892+
query_job_config: bq.QueryJobConfig | None = None,
885893
**kwargs: Any,
886894
) -> pa.Table:
887895
self._import_pyarrow()
888896

889897
table_expr = expr.as_table()
890898
schema = table_expr.schema() - ibis.schema({"_TABLE_SUFFIX": "string"})
891899

892-
query = self._to_query(table_expr, params=params, limit=limit, **kwargs)
900+
query = self._to_query(
901+
table_expr,
902+
params=params,
903+
limit=limit,
904+
query_job_config=query_job_config,
905+
**kwargs,
906+
)
893907
table = query.to_arrow(
894908
progress_bar_type=None, bqstorage_client=self.storage_client
895909
)
@@ -904,6 +918,7 @@ def to_pyarrow_batches(
904918
params: Mapping[ir.Scalar, Any] | None = None,
905919
limit: int | str | None = None,
906920
chunk_size: int = 1_000_000,
921+
query_job_config: bq.QueryJobConfig | None = None,
907922
**kwargs: Any,
908923
):
909924
pa = self._import_pyarrow()
@@ -912,7 +927,13 @@ def to_pyarrow_batches(
912927
schema = table_expr.schema() - ibis.schema({"_TABLE_SUFFIX": "string"})
913928
colnames = list(schema.names)
914929

915-
query = self._to_query(table_expr, params=params, limit=limit, **kwargs)
930+
query = self._to_query(
931+
table_expr,
932+
params=params,
933+
limit=limit,
934+
query_job_config=query_job_config,
935+
**kwargs,
936+
)
916937
batch_iter = query.to_arrow_iterable(bqstorage_client=self.storage_client)
917938
return pa.ipc.RecordBatchReader.from_batches(
918939
schema.to_pyarrow(),
@@ -926,6 +947,7 @@ def execute(
926947
*,
927948
params: Mapping[ir.Scalar, Any] | None = None,
928949
limit: int | str | None = None,
950+
query_job_config: bq.QueryJobConfig | None = None,
929951
**kwargs: Any,
930952
) -> pd.DataFrame | pd.Series | Any:
931953
"""Compile and execute the given Ibis expression.
@@ -942,6 +964,8 @@ def execute(
942964
already set on the expression.
943965
params
944966
Query parameters
967+
query_job_config
968+
Optional `QueryJobConfig` applied to the query job. Query parameters passed via `params` take precedence over same-named parameters in this object.
945969
kwargs
946970
Extra arguments specific to the backend
947971
@@ -955,7 +979,13 @@ def execute(
955979

956980
table_expr = expr.as_table()
957981
schema = table_expr.schema() - ibis.schema({"_TABLE_SUFFIX": "string"})
958-
query = self._to_query(table_expr, params=params, limit=limit, **kwargs)
982+
query = self._to_query(
983+
table_expr,
984+
params=params,
985+
limit=limit,
986+
query_job_config=query_job_config,
987+
**kwargs,
988+
)
959989
df = query.to_arrow(
960990
progress_bar_type=None, bqstorage_client=self.storage_client
961991
).to_pandas(timestamp_as_object=True)
@@ -1245,6 +1275,32 @@ def _safe_raw_sql(self, *args, **kwargs):
12451275
yield self.raw_sql(*args, **kwargs)
12461276

12471277

1278+
def _merge_params_into_config(
    query_job_config: bq.QueryJobConfig | None = None,
    params: Mapping[ir.Scalar, Any] | None = None,
) -> bq.QueryJobConfig:
    """Merge ibis scalar parameters into a BigQuery `QueryJobConfig`.

    Parameters
    ----------
    query_job_config
        Base configuration. It is deep-copied and never modified. When
        `None`, a fresh `QueryJobConfig` is used as the base.
    params
        Mapping of ibis scalar parameter expressions to their values. Each
        entry is converted with `bigquery_param` under the expression's name.

    Returns
    -------
    bq.QueryJobConfig
        A new config whose `query_parameters` holds the parameters already
        present in `query_job_config` followed by those built from `params`.
        A named parameter from `params` replaces a config parameter with the
        same name, keeping that parameter's position in the list.

    Notes
    -----
    Parameters in `query_job_config` whose `name` is `None` (positional
    parameters) are all preserved, in order, ahead of the named ones.
    NOTE(review): BigQuery is documented as not accepting a mix of named and
    positional parameters in one query; this function does not validate
    that -- confirm whether an explicit error would be preferable here.
    """
    if query_job_config is None:
        merged_config = bq.QueryJobConfig()
    else:
        # Work on a copy so the caller's config object is left untouched.
        merged_config = copy.deepcopy(query_job_config)

    # Unnamed (positional) parameters cannot be keyed by name: they would all
    # share the key `None` and overwrite one another, so keep them separately.
    positional_params = []
    named_params = {}
    for config_param in merged_config.query_parameters:
        if config_param.name is None:
            positional_params.append(config_param)
        else:
            named_params[config_param.name] = config_param

    # Later assignments win, so `params` overrides same-named config entries
    # while dict insertion order keeps the original position of the key.
    for param, value in (params or {}).items():
        name = param.get_name()
        named_params[name] = bigquery_param(param.type(), value, name)

    merged_config.query_parameters = [*positional_params, *named_params.values()]
    return merged_config
1302+
1303+
12481304
def compile(expr, params=None, **kwargs):
12491305
"""Compile an expression for BigQuery."""
12501306
backend = Backend()
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from __future__ import annotations
2+
3+
from datetime import date
4+
5+
import pytest
6+
from google.cloud.bigquery import QueryJobConfig
7+
from google.cloud.bigquery.query import ScalarQueryParameter
8+
from google.cloud.bigquery.table import TableReference
9+
10+
import ibis
11+
from ibis.backends.bigquery import _merge_params_into_config
12+
13+
14+
def _api_repr_without_query_parameters(config):
    """Return `config.to_api_repr()` with the `queryParameters` entry removed."""
    api_repr = config.to_api_repr()
    api_repr["query"].pop("queryParameters", None)
    return api_repr


@pytest.mark.parametrize(
    ("query_job_config", "params", "expected"),
    [
        # nothing to merge on either side
        (None, None, []),
        (QueryJobConfig(), None, []),
        (None, {}, []),
        (QueryJobConfig(), {}, []),
        # parameters only in the config
        (
            QueryJobConfig(
                query_parameters=[
                    ScalarQueryParameter("param1", "INT64", 1),
                    ScalarQueryParameter("param2", "INT64", 2),
                ],
            ),
            None,
            [
                ScalarQueryParameter("param1", "INT64", 1),
                ScalarQueryParameter("param2", "INT64", 2),
            ],
        ),
        # parameters only in `params`
        (
            None,
            {
                ibis.literal(0).name("param1"): 1,
                ibis.literal(0).name("param2"): 2,
            },
            [
                ScalarQueryParameter("param1", "INT64", 1),
                ScalarQueryParameter("param2", "INT64", 2),
            ],
        ),
        # name conflict: the value from `params` wins
        (
            QueryJobConfig(
                query_parameters=[
                    ScalarQueryParameter("param1", "INT64", 1),
                    ScalarQueryParameter("param2", "INT64", 2),
                ],
            ),
            {
                ibis.literal(0).name("param2"): 3,
                ibis.literal(0).name("param3"): 4,
            },
            [
                ScalarQueryParameter("param1", "INT64", 1),
                ScalarQueryParameter("param2", "INT64", 3),
                ScalarQueryParameter("param3", "INT64", 4),
            ],
        ),
        # disjoint names across several types, plus an unrelated config field
        (
            QueryJobConfig(
                query_parameters=[
                    ScalarQueryParameter("config1", "BOOL", True),
                    ScalarQueryParameter("config2", "INT64", 1),
                    ScalarQueryParameter("config3", "FLOAT64", 2.3),
                    ScalarQueryParameter("config4", "STRING", "abc"),
                    ScalarQueryParameter("config5", "DATE", "2025-01-01"),
                ],
                # ensure this is preserved
                destination=TableReference.from_string(
                    "test_project.test_dataset.test_table",
                ),
            ),
            {
                ibis.literal(False).name("param1"): False,
                ibis.literal(0).name("param2"): 4,
                ibis.literal(0.0).name("param3"): 5.6,
                ibis.literal("").name("param4"): "def",
                ibis.literal(date.today()).name("param5"): date(2025, 1, 2),
            },
            [
                ScalarQueryParameter("config1", "BOOL", True),
                ScalarQueryParameter("config2", "INT64", 1),
                ScalarQueryParameter("config3", "FLOAT64", 2.3),
                ScalarQueryParameter("config4", "STRING", "abc"),
                ScalarQueryParameter("config5", "DATE", date(2025, 1, 1)),
                ScalarQueryParameter("param1", "BOOL", False),
                ScalarQueryParameter("param2", "INT64", 4),
                ScalarQueryParameter("param3", "FLOAT64", 5.6),
                ScalarQueryParameter("param4", "STRING", "def"),
                ScalarQueryParameter("param5", "DATE", date(2025, 1, 2)),
            ],
        ),
    ],
)
def test__merge_params_into_config(query_job_config, params, expected):
    """Check the merged parameters and that other config fields survive."""
    merged = _merge_params_into_config(query_job_config, params)

    # the merge must hand back a new object holding the expected parameters
    assert merged is not query_job_config
    assert merged.query_parameters == expected

    # without an input config there are no other fields to compare
    if query_job_config is None:
        return

    # everything except the query parameters must match the input config
    expected_repr = _api_repr_without_query_parameters(query_job_config)
    result_repr = _api_repr_without_query_parameters(merged)
    assert result_repr == expected_repr

0 commit comments

Comments
 (0)