Skip to content

Commit c475329

Browse files
[Snowflake Connector] Add tests for connector's SQL overrides.
1 parent dd32a09 commit c475329

File tree

1 file changed

+39
-8
lines changed

1 file changed

+39
-8
lines changed

dumper/app/src/test/java/com/google/edwmigration/dumper/application/dumper/connector/snowflake/SnowflakeMetadataConnectorTest.java

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
*/
1717
package com.google.edwmigration.dumper.application.dumper.connector.snowflake;
1818

19+
import static org.junit.Assert.assertEquals;
20+
import static org.junit.Assert.assertThrows;
21+
import static org.junit.Assert.assertTrue;
22+
1923
import com.google.common.collect.ImmutableMap;
2024
import com.google.common.io.Resources;
2125
import com.google.edwmigration.dumper.application.dumper.ConnectorArguments;
@@ -34,6 +38,8 @@
3438
import java.util.List;
3539
import java.util.Map;
3640
import javax.annotation.Nonnull;
41+
import org.apache.commons.lang3.ArrayUtils;
42+
import org.apache.commons.lang3.StringUtils;
3743
import org.junit.Assert;
3844
import org.junit.Assume;
3945
import org.junit.Test;
@@ -122,19 +128,19 @@ public void testDatabaseNameFailure() {
122128
Assume.assumeTrue(isDumperTest());
123129

124130
MetadataDumperUsageException exception =
125-
Assert.assertThrows(
131+
assertThrows(
126132
MetadataDumperUsageException.class,
127133
() -> {
128134
File outputFile =
129135
TestUtils.newOutputFile("compilerworks-snowflake-metadata-fail.zip");
130136
String[] args = ARGS(connector, outputFile);
131137

132-
Assert.assertEquals("--database", args[6]);
138+
assertEquals("--database", args[6]);
133139
args[7] = args[7] + "_NOT_EXISTS";
134140
run(args);
135141
});
136142

137-
Assert.assertTrue(exception.getMessage().startsWith("Database name not found"));
143+
assertTrue(exception.getMessage().startsWith("Database name not found"));
138144
}
139145

140146
@Test
@@ -147,17 +153,42 @@ public void connector_generatesExpectedSql() throws IOException {
147153
StandardCharsets.UTF_8),
148154
TaskSqlMap.class);
149155

150-
Assert.assertEquals(expectedSqls.size(), actualSqls.size());
151-
Assert.assertEquals(expectedSqls.keySet(), actualSqls.keySet());
156+
assertEquals(expectedSqls.size(), actualSqls.size());
157+
assertEquals(expectedSqls.keySet(), actualSqls.keySet());
152158
for (String name : expectedSqls.keySet()) {
153-
Assert.assertEquals(expectedSqls.get(name), actualSqls.get(name));
159+
assertEquals(expectedSqls.get(name), actualSqls.get(name));
154160
}
155161
}
156162

157-
private static Map<String, String> collectSqlStatements() throws IOException {
163+
@Test
164+
public void connector_generatesExpectedSql_withQueryOverrides() throws IOException {
165+
Map<String, String> actualSqls =
166+
collectSqlStatements("-Dsnowflake.metadata.columns.query=SQL_OVERRIDE");
167+
168+
assertEquals("SQL_OVERRIDE", actualSqls.get("columns-au.csv"));
169+
assertEquals("SQL_OVERRIDE", actualSqls.get("columns.csv"));
170+
}
171+
172+
@Test
173+
public void connector_generatesExpectedSql_withWhereOverrides() throws IOException {
174+
Map<String, String> actualSqls =
175+
collectSqlStatements("-Dsnowflake.metadata.columns.where=SQL_OVERRIDE");
176+
177+
// TODO: should be endsWith("WHERE SQL_OVERRIDE")
178+
assertTrue(
179+
actualSqls.get("columns-au.csv").endsWith("WHERE DELETED IS NULL WHERE SQL_OVERRIDE"));
180+
// TODO: should be 1
181+
assertEquals(2, StringUtils.countMatches(actualSqls.get("columns-au.csv"), " WHERE "));
182+
183+
assertTrue(actualSqls.get("columns.csv").endsWith("WHERE SQL_OVERRIDE"));
184+
assertEquals(1, StringUtils.countMatches(actualSqls.get("columns.csv"), " WHERE "));
185+
}
186+
187+
private static Map<String, String> collectSqlStatements(String... extraArgs) throws IOException {
158188
List<Task<?>> tasks = new ArrayList<>();
159189
SnowflakeMetadataConnector connector = new SnowflakeMetadataConnector();
160-
connector.addTasksTo(tasks, new ConnectorArguments("--connector", connector.getName()));
190+
String[] args = ArrayUtils.addAll(new String[] {"--connector", connector.getName()}, extraArgs);
191+
connector.addTasksTo(tasks, new ConnectorArguments(args));
161192
return tasks.stream()
162193
.filter(t -> t instanceof JdbcSelectTask)
163194
.map(t -> (JdbcSelectTask) t)

0 commit comments

Comments
 (0)