|
10 | 10 | from flytekit import PodTemplate |
11 | 11 | from flytekit.core import context_manager |
12 | 12 | from flytekitplugins.spark import Spark |
13 | | -from flytekitplugins.spark.task import Databricks, new_spark_session |
| 13 | +from flytekitplugins.spark.task import Databricks, DatabricksV2, new_spark_session |
14 | 14 | from pyspark.sql import SparkSession |
15 | 15 |
|
16 | 16 | import flytekit |
@@ -135,6 +135,46 @@ def my_databricks(a: int) -> int: |
135 | 135 | assert my_databricks(a=3) == 3 |
136 | 136 |
|
137 | 137 |
|
@pytest.mark.parametrize("spark_conf", [None, {"spark": "2"}])
def test_databricks_v2(reset_spark_session, spark_conf):
    """Check that a DatabricksV2 task config keeps the Databricks job
    settings and that the decorated task still runs locally as plain Python.
    """
    job_settings = {
        "name": "flytekit databricks plugin example",
        "new_cluster": {
            "spark_version": "11.0.x-scala2.12",
            "node_type_id": "r3.xlarge",
            "aws_attributes": {"availability": "ON_DEMAND"},
            "num_workers": 4,
            "docker_image": {"url": "pingsutw/databricks:latest"},
        },
        "timeout_seconds": 3600,
        "max_retries": 1,
        "spark_python_task": {
            "python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
            "parameters": "ls",
        },
    }
    instance_host = "account.cloud.databricks.com"

    # Build the task config up front so the decorator stays compact.
    task_cfg = DatabricksV2(
        databricks_conf=job_settings,
        databricks_instance=instance_host,
        spark_conf=spark_conf,
    )

    @task(task_config=task_cfg)
    def my_databricks(a: int) -> int:
        # Local execution should provide a Spark session with the default
        # FlyteSpark application name.
        session = flytekit.current_context().spark_session
        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
        return a

    cfg = my_databricks.task_config
    assert cfg is not None
    assert cfg.databricks_conf == job_settings
    assert cfg.databricks_instance == instance_host
    # A spark_conf of None is normalized to an empty dict on the config.
    assert cfg.spark_conf == (spark_conf or {})
    assert my_databricks(a=3) == 3


138 | 178 | def test_new_spark_session(): |
139 | 179 | name = "SessionName" |
140 | 180 | spark_conf = {"spark1": "1", "spark2": "2"} |
|
0 commit comments