Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict:
if not new_cluster.get("docker_image"):
new_cluster["docker_image"] = {"url": container.image}
if not new_cluster.get("spark_conf"):
new_cluster["spark_conf"] = custom["sparkConf"]
new_cluster["spark_conf"] = custom.get("sparkConf", {})
if not new_cluster.get("spark_env_vars"):
new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()}
else:
Expand Down
42 changes: 41 additions & 1 deletion plugins/flytekit-spark/tests/test_spark_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from flytekit import PodTemplate
from flytekit.core import context_manager
from flytekitplugins.spark import Spark
from flytekitplugins.spark.task import Databricks, new_spark_session
from flytekitplugins.spark.task import Databricks, DatabricksV2, new_spark_session
from pyspark.sql import SparkSession

import flytekit
Expand Down Expand Up @@ -135,6 +135,46 @@ def my_databricks(a: int) -> int:
assert my_databricks(a=3) == 3


@pytest.mark.parametrize("spark_conf", [None, {"spark": "2"}])
def test_databricks_v2(reset_spark_session, spark_conf):
    """DatabricksV2 task config should round-trip through the @task decorator.

    Runs once with no spark_conf (expects it to default to {}) and once with
    an explicit spark_conf dict, then executes the task locally.
    """
    conf = {
        "name": "flytekit databricks plugin example",
        "new_cluster": {
            "spark_version": "11.0.x-scala2.12",
            "node_type_id": "r3.xlarge",
            "aws_attributes": {"availability": "ON_DEMAND"},
            "num_workers": 4,
            "docker_image": {"url": "pingsutw/databricks:latest"},
        },
        "timeout_seconds": 3600,
        "max_retries": 1,
        "spark_python_task": {
            "python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
            "parameters": "ls",
        },
    }
    instance = "account.cloud.databricks.com"

    @task(
        task_config=DatabricksV2(
            databricks_conf=conf,
            databricks_instance=instance,
            spark_conf=spark_conf,
        )
    )
    def my_databricks(a: int) -> int:
        session = flytekit.current_context().spark_session
        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
        return a

    # The decorator must preserve the config verbatim; a missing spark_conf
    # falls back to an empty dict rather than None.
    cfg = my_databricks.task_config
    assert cfg is not None
    assert cfg.databricks_conf == conf
    assert cfg.databricks_instance == instance
    assert cfg.spark_conf == (spark_conf or {})

    # Local execution should behave like a plain Python call.
    assert my_databricks(a=3) == 3


def test_new_spark_session():
name = "SessionName"
spark_conf = {"spark1": "1", "spark2": "2"}
Expand Down
Loading