
Commit 3da26d9

[BUG] KeyError: 'sparkConf' occurs when running a Databricks task without spark_conf (#3263)
Signed-off-by: machichima <[email protected]>
1 parent 614fbea commit 3da26d9

File tree

2 files changed: +42 additions, -2 deletions

plugins/flytekit-spark/flytekitplugins/spark/connector.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict:
         if not new_cluster.get("docker_image"):
             new_cluster["docker_image"] = {"url": container.image}
         if not new_cluster.get("spark_conf"):
-            new_cluster["spark_conf"] = custom["sparkConf"]
+            new_cluster["spark_conf"] = custom.get("sparkConf", {})
         if not new_cluster.get("spark_env_vars"):
             new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()}
     else:
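
For context, a minimal sketch of the failure mode this one-line change fixes (the dicts below are stand-ins, not the plugin's real task template data): when a Databricks task is defined without spark_conf, the serialized custom config has no "sparkConf" key, so direct indexing raised a KeyError, while dict.get with a default falls back to an empty dict.

# Minimal sketch of the failure mode (stand-in data, not the plugin's actual task template).
custom = {"databricksConf": {"run_name": "example"}}  # no "sparkConf" key present
new_cluster = {}

# Before the fix: direct indexing raises KeyError: 'sparkConf'
try:
    new_cluster["spark_conf"] = custom["sparkConf"]
except KeyError as exc:
    print(f"KeyError: {exc}")

# After the fix: .get() falls back to an empty dict
new_cluster["spark_conf"] = custom.get("sparkConf", {})
print(new_cluster)  # {'spark_conf': {}}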

plugins/flytekit-spark/tests/test_spark_task.py

Lines changed: 41 additions & 1 deletion
@@ -10,7 +10,7 @@
 from flytekit import PodTemplate
 from flytekit.core import context_manager
 from flytekitplugins.spark import Spark
-from flytekitplugins.spark.task import Databricks, new_spark_session
+from flytekitplugins.spark.task import Databricks, DatabricksV2, new_spark_session
 from pyspark.sql import SparkSession
 
 import flytekit
@@ -135,6 +135,46 @@ def my_databricks(a: int) -> int:
     assert my_databricks(a=3) == 3
 
 
+@pytest.mark.parametrize("spark_conf", [None, {"spark": "2"}])
+def test_databricks_v2(reset_spark_session, spark_conf):
+    databricks_conf = {
+        "name": "flytekit databricks plugin example",
+        "new_cluster": {
+            "spark_version": "11.0.x-scala2.12",
+            "node_type_id": "r3.xlarge",
+            "aws_attributes": {"availability": "ON_DEMAND"},
+            "num_workers": 4,
+            "docker_image": {"url": "pingsutw/databricks:latest"},
+        },
+        "timeout_seconds": 3600,
+        "max_retries": 1,
+        "spark_python_task": {
+            "python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
+            "parameters": "ls",
+        },
+    }
+
+    databricks_instance = "account.cloud.databricks.com"
+
+    @task(
+        task_config=DatabricksV2(
+            databricks_conf=databricks_conf,
+            databricks_instance=databricks_instance,
+            spark_conf=spark_conf,
+        )
+    )
+    def my_databricks(a: int) -> int:
+        session = flytekit.current_context().spark_session
+        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
+        return a
+
+    assert my_databricks.task_config is not None
+    assert my_databricks.task_config.databricks_conf == databricks_conf
+    assert my_databricks.task_config.databricks_instance == databricks_instance
+    assert my_databricks.task_config.spark_conf == (spark_conf or {})
+    assert my_databricks(a=3) == 3
+
+
 def test_new_spark_session():
     name = "SessionName"
     spark_conf = {"spark1": "1", "spark2": "2"}
