Commit 4495e93

Add Snowflake plugin with integration test
1 parent a6f63b1 commit 4495e93

6 files changed: +279 −1 lines changed

core/pom.xml

Lines changed: 5 additions & 0 deletions
@@ -139,6 +139,11 @@
             <artifactId>mongo-spark-connector_${scala.binary.version}</artifactId>
             <optional>true</optional>
         </dependency>
+        <dependency>
+            <groupId>net.snowflake</groupId>
+            <artifactId>spark-snowflake_${scala.binary.version}</artifactId>
+            <optional>true</optional>
+        </dependency>
         <dependency>
             <groupId>org.elasticsearch</groupId>
             <artifactId>elasticsearch-hadoop</artifactId>
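Because the dependency is declared optional here, it is not pulled in transitively by consumers of the core module; an application that wants Snowflake lineage has to put spark-snowflake on its own classpath. A minimal sketch (a hypothetical helper, not part of this commit) for checking at runtime whether the connector classes are visible:

// Sketch only; not part of this commit. Checks whether the optional Snowflake connector
// is on the application classpath; without the connector there is nothing for the plugin
// to capture.
object SnowflakeConnectorCheck extends App {
  val connectorPresent: Boolean =
    try {
      Class.forName("net.snowflake.spark.snowflake.DefaultSource")
      true
    } catch {
      case _: ClassNotFoundException => false
    }

  println(s"spark-snowflake on classpath: $connectorPresent")
}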
SnowflakePlugin.scala

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
/*
 * Copyright 2021 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.spline.harvester.plugin.embedded

import za.co.absa.spline.commons.reflect.ReflectionUtils.extractValue
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.{LogicalRelation, SaveIntoDataSourceCommand}
import org.apache.spark.sql.sources.BaseRelation
import za.co.absa.spline.commons.reflect.extractors.SafeTypeMatchingExtractor
import za.co.absa.spline.harvester.builder.SourceIdentifier
import za.co.absa.spline.harvester.plugin.Plugin.{Precedence, ReadNodeInfo, WriteNodeInfo}
import za.co.absa.spline.harvester.plugin.embedded.SnowflakePlugin._
import za.co.absa.spline.harvester.plugin.{BaseRelationProcessing, Plugin, RelationProviderProcessing}

import javax.annotation.Priority
import scala.language.reflectiveCalls

@Priority(Precedence.Normal)
class SnowflakePlugin(spark: SparkSession)
  extends Plugin
    with BaseRelationProcessing
    with RelationProviderProcessing {

  import za.co.absa.spline.commons.ExtractorImplicits._

  override def baseRelationProcessor: PartialFunction[(BaseRelation, LogicalRelation), ReadNodeInfo] = {
    case (`_: SnowflakeRelation`(r), _) =>
      val params = extractValue[net.snowflake.spark.snowflake.Parameters.MergedParameters](r, "params")

      val url: String = params.sfURL
      val warehouse: String = params.sfWarehouse.getOrElse("")
      val database: String = params.sfDatabase
      val schema: String = params.sfSchema
      val table: String = params.table.getOrElse("").toString

      ReadNodeInfo(asSourceId(url, warehouse, database, schema, table), Map.empty)
  }

  override def relationProviderProcessor: PartialFunction[(AnyRef, SaveIntoDataSourceCommand), WriteNodeInfo] = {
    case (rp, cmd) if rp == "net.snowflake.spark.snowflake.DefaultSource" || SnowflakeSourceExtractor.matches(rp) =>
      val url: String = cmd.options("sfUrl")
      val warehouse: String = cmd.options("sfWarehouse")
      val database: String = cmd.options("sfDatabase")
      val schema: String = cmd.options("sfSchema")
      val table: String = cmd.options("dbtable")

      WriteNodeInfo(asSourceId(url, warehouse, database, schema, table), cmd.mode, cmd.query, cmd.options)
  }
}

object SnowflakePlugin {

  private object `_: SnowflakeRelation` extends SafeTypeMatchingExtractor[AnyRef]("net.snowflake.spark.snowflake.SnowflakeRelation")

  private object SnowflakeSourceExtractor extends SafeTypeMatchingExtractor(classOf[net.snowflake.spark.snowflake.DefaultSource])

  private def asSourceId(url: String, warehouse: String, database: String, schema: String, table: String) =
    SourceIdentifier(Some("snowflake"), s"snowflake://$url.$warehouse.$database.$schema.$table")
}
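For reference, both the read and the write path above reduce to the same URI shape built by the private asSourceId helper. A small standalone sketch of that mapping, using made-up option values (it mirrors the helper, it is not part of this commit):

// Sketch only: mirrors SnowflakePlugin.asSourceId; the values below are illustrative.
object SnowflakeUriExample extends App {
  def snowflakeSourceUri(url: String, warehouse: String, database: String, schema: String, table: String): String =
    s"snowflake://$url.$warehouse.$database.$schema.$table"

  // prints "snowflake://acme.snowflakecomputing.com.WH.SALES.PUBLIC.ORDERS"
  println(snowflakeSourceUri("acme.snowflakecomputing.com", "WH", "SALES", "PUBLIC", "ORDERS"))
}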
SnowflakePluginSpec.scala

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
/*
 * Copyright 2021 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.spline.harvester.plugin.embedded

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.execution.datasources.{LogicalRelation, SaveIntoDataSourceCommand}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import za.co.absa.spline.harvester.plugin.Plugin.{ReadNodeInfo, WriteNodeInfo}
import za.co.absa.spline.harvester.builder.SourceIdentifier
import org.mockito.Mockito.{mock, _}
import net.snowflake.spark.snowflake.Parameters
import net.snowflake.spark.snowflake.Parameters.MergedParameters
import org.apache.spark.sql.sources.BaseRelation
import za.co.absa.spline.commons.reflect.extractors.SafeTypeMatchingExtractor
import za.co.absa.spline.commons.reflect.{ReflectionUtils, ValueExtractor}

class SnowflakePluginSpec extends AnyFlatSpec with Matchers with MockitoSugar {

  "SnowflakePlugin" should "process Snowflake relation providers" in {
    // Setup
    val spark = mock[SparkSession]
    val plugin = new SnowflakePlugin(spark)

    val options = Map(
      "sfUrl" -> "test-url",
      "sfWarehouse" -> "test-warehouse",
      "sfDatabase" -> "test-database",
      "sfSchema" -> "test-schema",
      "sfUser" -> "user1",
      "dbtable" -> "test-table"
    )

    val cmd = mock[SaveIntoDataSourceCommand]
    when(cmd.options) thenReturn options
    when(cmd.mode) thenReturn SaveMode.Overwrite
    when(cmd.query) thenReturn null

    // Mocking the relation provider to be Snowflake
    val snowflakeRP = "net.snowflake.spark.snowflake.DefaultSource"

    // Execute
    val result = plugin.relationProviderProcessor((snowflakeRP, cmd))

    // Verify
    val expectedSourceId = SourceIdentifier(Some("snowflake"), "snowflake://test-url.test-warehouse.test-database.test-schema.test-table")
    result shouldEqual WriteNodeInfo(expectedSourceId, SaveMode.Overwrite, null, options)
  }

  it should "not process non-Snowflake relation providers" in {
    // Setup
    val spark = mock[SparkSession]
    val plugin = new SnowflakePlugin(spark)

    val cmd = mock[SaveIntoDataSourceCommand]

    // Mocking the relation provider to be non-Snowflake
    val nonSnowflakeRP = "some.other.datasource"

    // Execute & Verify
    assertThrows[MatchError] {
      plugin.relationProviderProcessor((nonSnowflakeRP, cmd))
    }
  }
}

integration-tests/pom.xml

Lines changed: 11 additions & 0 deletions
@@ -131,6 +131,11 @@
             <artifactId>spark-cobol_${scala.binary.version}</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>net.snowflake</groupId>
+            <artifactId>spark-snowflake_${scala.binary.version}</artifactId>
+            <optional>true</optional>
+        </dependency>

         <!-- to force newer version of jackson-annotations - needed for testcontainers -->
         <dependency>
@@ -163,6 +168,12 @@
             <version>${testcontainers.version}</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>localstack</artifactId>
+            <version>1.19.8</version>
+            <scope>test</scope>
+        </dependency>

         <!-- required for spark-cassandra-connector -->
         <dependency>
SnowflakeSpec.scala

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
/*
 * Copyright 2019 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.spline

import org.apache.spark.sql.{Row, RowFactory}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AsyncFlatSpec
import org.scalatest.matchers.should.Matchers
import org.testcontainers.containers.GenericContainer
import org.testcontainers.containers.wait.strategy.Wait
import org.testcontainers.utility.DockerImageName
import za.co.absa.spline.commons.io.TempDirectory
import za.co.absa.spline.test.fixture.spline.SplineFixture
import za.co.absa.spline.test.fixture.{ReleasableResourceFixture, SparkFixture}

import java.util

class SnowflakeSpec
  extends AsyncFlatSpec
    with BeforeAndAfterAll
    with Matchers
    with SparkFixture
    with SplineFixture
    with ReleasableResourceFixture {

  val tableName = "testTable"
  val schemaName = "testSchema"
  val warehouseName = "testWarehouse"
  val databaseName = "test"
  val sparkFormat = "net.snowflake.spark.snowflake"

  it should "support snowflake as a read and write source" in {
    usingResource(new GenericContainer(DockerImageName.parse("localstack/snowflake"))) { container =>
      container.start()
      Wait.forHealthcheck

      val host = container.getHost

      withNewSparkSession { implicit spark =>

        withLineageTracking { captor =>
          val sfOptions = Map(
            "sfURL" -> "snowflake.localhost.localstack.cloud",
            "sfUser" -> "test",
            "sfPassword" -> "test",
            "sfDatabase" -> databaseName,
            "sfWarehouse" -> warehouseName,
            "sfSchema" -> schemaName
          )

          // Define the test data as a Java list of Rows
          val data = new util.ArrayList[Row]()
          data.add(RowFactory.create(1.asInstanceOf[Object]))
          data.add(RowFactory.create(2.asInstanceOf[Object]))
          data.add(RowFactory.create(3.asInstanceOf[Object]))

          // Create the DataFrame from the Java list
          val testData = spark.sqlContext.createDataFrame(data, classOf[Row])

          for {
            (writePlan, _) <- captor.lineageOf(
              testData.write
                .format(sparkFormat)
                .options(sfOptions)
                .option("dbtable", tableName)
                .mode("overwrite")
                .save()
            )

            (readPlan, _) <- captor.lineageOf {
              val df = spark.read.format(sparkFormat)
                .options(sfOptions)
                .option("dbtable", tableName) // specify the source table
                .load()

              df.write.save(TempDirectory(pathOnly = true).deleteOnExit().path.toString)
            }
          } yield {
            writePlan.operations.write.append shouldBe false
            writePlan.operations.write.extra("destinationType") shouldBe Some("snowflake")
            writePlan.operations.write.outputSource shouldBe s"snowflake://$host.$warehouseName.$databaseName.$schemaName.$tableName"

            readPlan.operations.reads.head.inputSources.head shouldBe writePlan.operations.write.outputSource
            readPlan.operations.reads.head.extra("sourceType") shouldBe Some("snowflake")
            readPlan.operations.write.append shouldBe false
          }
        }
      }
    }
  }

}

pom.xml

Lines changed: 6 additions & 1 deletion
@@ -83,7 +83,7 @@

         <!-- Spark -->

-        <spark.version>${spark-24.version}</spark.version>
+        <spark.version>${spark-33.version}</spark.version>
         <spark-22.version>2.2.3</spark-22.version>
         <spark-23.version>2.3.4</spark-23.version>
         <spark-24.version>2.4.8</spark-24.version>
@@ -452,6 +452,11 @@
             <artifactId>mongo-spark-connector_${scala.binary.version}</artifactId>
             <version>2.4.1</version>
         </dependency>
+        <dependency>
+            <groupId>net.snowflake</groupId>
+            <artifactId>spark-snowflake_${scala.binary.version}</artifactId>
+            <version>2.16.0-spark_3.3</version>
+        </dependency>
         <dependency>
             <groupId>org.elasticsearch</groupId>
             <artifactId>elasticsearch-hadoop</artifactId>

0 commit comments