Skip to content

Commit 2798dc4

Browse files
committed
Add precision as an optional parameter in set-type
1 parent 297413a commit 2798dc4

File tree

8 files changed

+164
-44
lines changed

8 files changed

+164
-44
lines changed

wrangler-core/src/main/java/io/cdap/directives/column/SetType.java

+65-9
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
import io.cdap.cdap.api.annotation.Name;
2121
import io.cdap.cdap.api.annotation.Plugin;
2222
import io.cdap.cdap.api.data.schema.Schema;
23+
import io.cdap.cdap.api.data.schema.Schema.LogicalType;
2324
import io.cdap.wrangler.api.Arguments;
2425
import io.cdap.wrangler.api.Directive;
2526
import io.cdap.wrangler.api.DirectiveExecutionException;
2627
import io.cdap.wrangler.api.DirectiveParseException;
2728
import io.cdap.wrangler.api.ExecutorContext;
2829
import io.cdap.wrangler.api.Optional;
30+
import io.cdap.wrangler.api.Pair;
2931
import io.cdap.wrangler.api.Row;
3032
import io.cdap.wrangler.api.SchemaResolutionContext;
3133
import io.cdap.wrangler.api.annotations.Categories;
@@ -40,26 +42,28 @@
4042
import io.cdap.wrangler.utils.ColumnConverter;
4143

4244
import java.math.RoundingMode;
45+
import java.util.HashMap;
4346
import java.util.List;
4447
import java.util.stream.Collectors;
4548

4649
/**
4750
* A Wrangler step for converting data type of column
4851
* Accepted types are: int, short, long, double, float, decimal, string, boolean and bytes
49-
* When decimal type is selected, can also specify the scale and rounding mode
52+
* When decimal type is selected, can also specify the scale, precision and rounding mode
5053
*/
5154
@Plugin(type = "directives")
5255
@Name(SetType.NAME)
5356
@Categories(categories = {"column"})
54-
@Description("Converting data type of a column. Optional arguments scale and rounding-mode " +
55-
"are used only when type is decimal.")
57+
@Description("Converting data type of a column. Optional arguments scale, precision and "
58+
+ "rounding-mode are used only when type is decimal.")
5659
public final class SetType implements Directive, Lineage {
5760
public static final String NAME = "set-type";
5861

5962
private String col;
6063
private String type;
6164
private Integer scale;
6265
private RoundingMode roundingMode;
66+
private Integer precision;
6367

6468
@Override
6569
public UsageDefinition define() {
@@ -68,6 +72,7 @@ public UsageDefinition define() {
6872
builder.define("type", TokenType.IDENTIFIER);
6973
builder.define("scale", TokenType.NUMERIC, Optional.TRUE);
7074
builder.define("rounding-mode", TokenType.TEXT, Optional.TRUE);
75+
builder.define("precision", TokenType.PROPERTIES, "prop:{precision=<precision>}", Optional.TRUE);
7176
return builder.build();
7277
}
7378

@@ -76,14 +81,19 @@ public void initialize(Arguments args) throws DirectiveParseException {
7681
col = ((ColumnName) args.value("column")).value();
7782
type = ((Identifier) args.value("type")).value();
7883
if (type.equalsIgnoreCase("decimal")) {
84+
precision = args.contains("precision") ? (Integer) ((HashMap<String, Numeric>) args.
85+
value("precision").value()).get("precision").value().intValue() : null;
86+
if (precision != null && precision < 1) {
87+
throw new DirectiveParseException("precision cannot be less than 1");
88+
}
7989
scale = args.contains("scale") ? ((Numeric) args.value("scale")).value().intValue() : null;
80-
if (scale == null && args.contains("rounding-mode")) {
81-
throw new DirectiveParseException("'rounding-mode' can only be specified when a 'scale' is set");
90+
if (scale == null && precision == null && args.contains("rounding-mode")) {
91+
throw new DirectiveParseException("'rounding-mode' can only be specified when a 'scale' or 'precision' is set");
8292
}
8393
try {
8494
roundingMode = args.contains("rounding-mode") ?
8595
RoundingMode.valueOf(((Text) args.value("rounding-mode")).value()) :
86-
(scale == null ? RoundingMode.UNNECESSARY : RoundingMode.HALF_EVEN);
96+
(scale == null && precision == null ? RoundingMode.UNNECESSARY : RoundingMode.HALF_EVEN);
8797
} catch (IllegalArgumentException e) {
8898
throw new DirectiveParseException(String.format(
8999
"Specified rounding-mode '%s' is not a valid Java rounding mode", args.value("rounding-mode").value()), e);
@@ -99,7 +109,7 @@ public void destroy() {
99109
@Override
100110
public List<Row> execute(List<Row> rows, ExecutorContext context) throws DirectiveExecutionException {
101111
for (Row row : rows) {
102-
ColumnConverter.convertType(NAME, row, col, type, scale, roundingMode);
112+
ColumnConverter.convertType(NAME, row, col, type, scale, precision, roundingMode);
103113
}
104114
return rows;
105115
}
@@ -121,8 +131,41 @@ public Schema getOutputSchema(SchemaResolutionContext context) {
121131
.map(
122132
field -> {
123133
try {
124-
return field.getName().equals(col) ?
125-
Schema.Field.of(col, ColumnConverter.getSchemaForType(type, scale)) : field;
134+
if (field.getName().equals(col)) {
135+
Integer outputScale = scale;
136+
Integer outputPrecision = precision;
137+
Schema fieldSchema = field.getSchema().getNonNullable();
138+
Pair<Integer, Integer> scaleAndPrecision = getPrecisionAndScale(fieldSchema);
139+
Integer inputSchemaScale = scaleAndPrecision.getSecond();
140+
Integer inputSchemaPrecision = scaleAndPrecision.getFirst();
141+
142+
if (scale == null && precision == null) {
143+
outputScale = inputSchemaScale;
144+
outputPrecision = inputSchemaPrecision;
145+
} else if (scale == null && inputSchemaScale != null) {
146+
if (precision - inputSchemaScale < 1) {
147+
throw new DirectiveParseException(String.format(
148+
"Cannot set scale as '%s' and precision as '%s' when "
149+
+ "given precision - scale is less than 1 ", inputSchemaScale,
150+
precision));
151+
}
152+
outputScale = inputSchemaScale;
153+
outputPrecision = precision;
154+
155+
} else if (precision == null && inputSchemaPrecision != null) {
156+
if (inputSchemaPrecision - scale < 1) {
157+
throw new DirectiveParseException(String.format(
158+
"Cannot set scale as '%s' and precision as '%s' when "
159+
+ "given precision - scale is less than 1 ", scale,
160+
inputSchemaPrecision));
161+
}
162+
outputScale = scale;
163+
outputPrecision = inputSchemaPrecision;
164+
}
165+
return Schema.Field.of(col, ColumnConverter.getSchemaForType(type,
166+
outputScale, outputPrecision));
167+
}
168+
return field;
126169
} catch (DirectiveParseException e) {
127170
throw new RuntimeException(e);
128171
}
@@ -131,4 +174,17 @@ public Schema getOutputSchema(SchemaResolutionContext context) {
131174
.collect(Collectors.toList())
132175
);
133176
}
177+
178+
/**
179+
* extracts precision and scale from schema string
180+
*/
181+
public static Pair<Integer, Integer> getPrecisionAndScale(Schema fieldSchema) {
182+
Integer precision = null;
183+
Integer scale = null;
184+
if (fieldSchema.getLogicalType() == LogicalType.DECIMAL) {
185+
precision = fieldSchema.getPrecision();
186+
scale = fieldSchema.getScale();
187+
}
188+
return new Pair<Integer, Integer>(precision, scale);
189+
}
134190
}

wrangler-core/src/main/java/io/cdap/directives/datamodel/DataModelMapColumn.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public void initialize(Arguments args) throws DirectiveParseException {
152152
@Override
153153
public List<Row> execute(List<Row> rows, ExecutorContext context) throws DirectiveExecutionException {
154154
for (Row row : rows) {
155-
ColumnConverter.convertType(NAME, row, column, targetFieldTypeName, null, RoundingMode.UNNECESSARY);
155+
ColumnConverter.convertType(NAME, row, column, targetFieldTypeName, null, null, RoundingMode.UNNECESSARY);
156156
ColumnConverter.rename(NAME, row, column, targetFieldName);
157157
}
158158
return rows;

wrangler-core/src/main/java/io/cdap/wrangler/parser/MigrateToV2.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,14 @@ public String migrate() throws DirectiveParseException {
126126
}
127127
break;
128128

129-
//set-type <column> <type> [<scale> <rounding-mode>]
129+
//set-type <column> <type> [<scale> <rounding-mode> prop:{precision=<precision>}]
130130
case "set-type": {
131131
String col = getNextToken(tokenizer, command, "col", lineno);
132132
String type = getNextToken(tokenizer, command, "type", lineno);
133133
String scale = getNextToken(tokenizer, null, command, "scale", lineno, true);
134134
String roundingMode = getNextToken(tokenizer, null, command, "rounding-mode", lineno, true);
135-
transformed.add(String.format("set-type %s %s %s %s;", col(col), type, scale, roundingMode));
135+
String precision = getNextToken(tokenizer, null, command, "precision", lineno, true);
136+
transformed.add(String.format("set-type %s %s %s %s %s;", col(col), type, scale, roundingMode, precision));
136137
}
137138
break;
138139

wrangler-core/src/main/java/io/cdap/wrangler/utils/ColumnConverter.java

+46-19
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
import io.cdap.cdap.api.common.Bytes;
1919
import io.cdap.cdap.api.data.schema.Schema;
20+
import io.cdap.cdap.api.data.schema.Schema.LogicalType;
2021
import io.cdap.wrangler.api.DirectiveExecutionException;
2122
import io.cdap.wrangler.api.DirectiveParseException;
2223
import io.cdap.wrangler.api.Row;
2324

2425
import java.math.BigDecimal;
26+
import java.math.MathContext;
2527
import java.math.RoundingMode;
2628
import java.util.Collections;
2729
import java.util.HashMap;
@@ -45,7 +47,7 @@ private ColumnConverter() {
4547
* @throws DirectiveExecutionException when a column matching the target name already exists
4648
*/
4749
public static void rename(String directiveName, Row row, String column, String toName)
48-
throws DirectiveExecutionException {
50+
throws DirectiveExecutionException {
4951
int idx = row.find(column);
5052
int existingColumn = row.find(toName);
5153
if (idx == -1) {
@@ -57,9 +59,9 @@ public static void rename(String directiveName, Row row, String column, String t
5759
row.setColumn(idx, toName);
5860
} else {
5961
throw new DirectiveExecutionException(
60-
directiveName, String.format("Column '%s' already exists. Apply the 'drop %s' directive before " +
61-
"renaming '%s' to '%s'.",
62-
toName, toName, column, toName));
62+
directiveName, String.format("Column '%s' already exists. Apply the 'drop %s' directive before " +
63+
"renaming '%s' to '%s'.",
64+
toName, toName, column, toName));
6365
}
6466
}
6567

@@ -73,8 +75,8 @@ public static void rename(String directiveName, Row row, String column, String t
7375
* @throws DirectiveExecutionException when an unsupported type is specified or the column can not be converted.
7476
*/
7577
public static void convertType(String directiveName, Row row, String column, String toType,
76-
Integer scale, RoundingMode roundingMode)
77-
throws DirectiveExecutionException {
78+
Integer scale, Integer precision, RoundingMode roundingMode)
79+
throws DirectiveExecutionException {
7880
int idx = row.find(column);
7981
if (idx != -1) {
8082
Object object = row.getValue(idx);
@@ -84,21 +86,22 @@ public static void convertType(String directiveName, Row row, String column, Str
8486
try {
8587
Object converted = ColumnConverter.convertType(column, toType, object);
8688
if (toType.equalsIgnoreCase(ColumnTypeNames.DECIMAL)) {
87-
row.setValue(idx, setDecimalScale((BigDecimal) converted, scale, roundingMode));
89+
row.setValue(idx, setDecimalScaleAndPrecision((BigDecimal) converted, scale,
90+
precision, roundingMode));
8891
} else {
8992
row.setValue(idx, converted);
9093
}
9194
} catch (DirectiveExecutionException e) {
9295
throw e;
9396
} catch (Exception e) {
9497
throw new DirectiveExecutionException(
95-
directiveName, String.format("Column '%s' cannot be converted to a '%s'.", column, toType), e);
98+
directiveName, String.format("Column '%s' cannot be converted to a '%s'.", column, toType), e);
9699
}
97100
}
98101
}
99102

100103
private static Object convertType(String col, String toType, Object object)
101-
throws Exception {
104+
throws Exception {
102105
toType = toType.toUpperCase();
103106
switch (toType) {
104107
case ColumnTypeNames.INTEGER:
@@ -291,38 +294,62 @@ private static Object convertType(String col, String toType, Object object)
291294

292295
default:
293296
throw new DirectiveExecutionException(String.format(
294-
"Column '%s' is of unsupported type '%s'. Supported types are: " +
295-
"int, short, long, double, decimal, boolean, string, bytes", col, toType));
297+
"Column '%s' is of unsupported type '%s'. Supported types are: " +
298+
"int, short, long, double, decimal, boolean, string, bytes", col, toType));
296299
}
297300
throw new DirectiveExecutionException(
298301
String.format("Column '%s' has value of type '%s' and cannot be converted to a '%s'.", col,
299302
object.getClass().getSimpleName(), toType));
300303
}
301304

302-
private static BigDecimal setDecimalScale(BigDecimal decimal, Integer scale, RoundingMode roundingMode)
303-
throws DirectiveExecutionException {
304-
if (scale == null) {
305+
private static BigDecimal setDecimalScaleAndPrecision(BigDecimal decimal, Integer scale,
306+
Integer precision, RoundingMode roundingMode)
307+
throws DirectiveExecutionException {
308+
if (scale == null && precision == null) {
305309
return decimal;
306310
}
307311
try {
308-
return decimal.setScale(scale, roundingMode);
312+
if (precision == null) {
313+
return decimal.setScale(scale, roundingMode);
314+
} else if (scale == null) {
315+
return decimal.round(new MathContext(precision, roundingMode));
316+
} else {
317+
BigDecimal result;
318+
if (validateScaleAndPrecision(scale, precision, decimal)) {
319+
result = decimal.setScale(scale, roundingMode);
320+
result = result.round(new MathContext(precision, roundingMode));
321+
} else {
322+
throw new DirectiveExecutionException(String.format(
323+
"Cannot set scale as '%s' and precision as '%s' for value '%s' when"
324+
+ "given precision - scale is less than number of digits"
325+
+ " before decimal point ", scale, precision, decimal));
326+
}
327+
return result;
328+
}
309329
} catch (ArithmeticException e) {
310330
throw new DirectiveExecutionException(String.format(
311-
"Cannot set scale as '%s' for value '%s' when rounding-mode is '%s'", scale, decimal, roundingMode), e);
331+
"Cannot set scale as '%s' and precision '%s' for value '%s' when rounding-mode "
332+
+ "is '%s'", scale, precision, decimal, roundingMode), e);
312333
}
313334
}
314335

315-
public static Schema getSchemaForType(String type, Integer scale) throws DirectiveParseException {
336+
private static Boolean validateScaleAndPrecision(Integer scale, Integer precision, BigDecimal decimal) {
337+
int digitsBeforeDecimalPoint = decimal.signum() == 0 ? 1 : decimal.precision() - decimal.scale();
338+
return precision - scale >= digitsBeforeDecimalPoint;
339+
}
340+
341+
public static Schema getSchemaForType(String type, Integer scale, Integer precision) throws DirectiveParseException {
316342
Schema typeSchema;
317343
type = type.toUpperCase();
318344
if (type.equals(ColumnTypeNames.DECIMAL)) {
319345
// Default to precision 77 and scale 38 when the directive does not supply them
346+
precision = precision != null ? precision : 77;
320347
scale = scale != null ? scale : 38;
321-
typeSchema = Schema.nullableOf(Schema.decimalOf(77, scale));
348+
typeSchema = Schema.nullableOf(Schema.decimalOf(precision, scale));
322349
} else {
323350
if (!SCHEMA_TYPE_MAP.containsKey(type)) {
324351
throw new DirectiveParseException(String.format("'%s' is an unsupported type. " +
325-
"Supported types are: int, short, long, double, decimal, boolean, string, bytes", type));
352+
"Supported types are: int, short, long, double, decimal, boolean, string, bytes", type));
326353
}
327354
typeSchema = Schema.nullableOf(Schema.of(SCHEMA_TYPE_MAP.get(type)));
328355
}

wrangler-core/src/test/java/io/cdap/directives/column/SetTypeTest.java

+44-9
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,48 @@ public void testToDecimalInvalidRoundingMode() throws Exception {
213213
TestingRig.execute(directives, rows);
214214
}
215215

216+
@Test
217+
public void testToDecimalWithScalePrecisionAndRoundingMode() throws Exception {
218+
List<Row> rows = Collections.singletonList(new Row("scale_1_precision_4", "122.5")
219+
.add("scale_3_precision_6", "456.789"));
220+
String[] directives = new String[] {"set-type :scale_1_precision_4 decimal 0 'FLOOR' prop:{precision=3}",
221+
"set-type :scale_3_precision_6 decimal 0 prop:{precision=5}"};
222+
List<Row> results = TestingRig.execute(directives, rows);
223+
Row row = results.get(0);
224+
225+
Assert.assertTrue(row.getValue(0) instanceof BigDecimal);
226+
Assert.assertEquals(row.getValue(0), new BigDecimal("122"));
227+
228+
Assert.assertTrue(row.getValue(1) instanceof BigDecimal);
229+
Assert.assertEquals(row.getValue(1), new BigDecimal("457"));
230+
}
231+
232+
@Test
233+
public void testToDecimalWithPrecision() throws Exception {
234+
List<Row> rows = Collections.singletonList(new Row("scale_1_precision_4", "122.5"));
235+
String[] directives = new String[] {"set-type :scale_1_precision_4 decimal 'FLOOR' prop:{precision=3}"};
236+
List<Row> results = TestingRig.execute(directives, rows);
237+
Row row = results.get(0);
238+
239+
Assert.assertTrue(row.getValue(0) instanceof BigDecimal);
240+
Assert.assertEquals(row.getValue(0), new BigDecimal("122"));
241+
242+
}
243+
244+
@Test(expected = RecipeException.class)
245+
public void testToDecimalWithInvalidPrecision() throws Exception {
246+
List<Row> rows = Collections.singletonList(new Row("scale_1_precision_4", "122.5"));
247+
String[] directives = new String[] {"set-type :scale_1_precision_4 decimal 0 'FLOOR' prop:{precision=-1}"};
248+
TestingRig.execute(directives, rows);
249+
}
250+
216251
@Test
217252
public void testToDecimalScaleIsNull() throws Exception {
218253
List<Row> rows = Collections.singletonList(new Row("scale_2", "125.45"));
219254
String[] directives = new String[] {"set-type scale_2 decimal"};
220255
Schema inputSchema = Schema.recordOf(
221256
"inputSchema",
222-
Schema.Field.of("scale_2", Schema.of(Schema.Type.DOUBLE))
257+
Schema.Field.of("scale_2", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE)))
223258
);
224259

225260
Schema expectedSchema = Schema.recordOf(
@@ -377,14 +412,14 @@ public void testGetOutputSchemaForTypeChangedColumn() throws Exception {
377412
.add("D", "random").add("E", 123).add("F", "true").add("G", 12L)
378413
);
379414
Schema inputSchema = Schema.recordOf(
380-
"inputSchema",
381-
Schema.Field.of("A", Schema.of(Schema.Type.STRING)),
382-
Schema.Field.of("B", Schema.of(Schema.Type.STRING)),
383-
Schema.Field.of("C", Schema.of(Schema.Type.STRING)),
384-
Schema.Field.of("D", Schema.of(Schema.Type.STRING)),
385-
Schema.Field.of("E", Schema.of(Schema.Type.INT)),
386-
Schema.Field.of("F", Schema.of(Schema.Type.STRING)),
387-
Schema.Field.of("G", Schema.of(Schema.Type.LONG))
415+
"inputSchema",
416+
Schema.Field.of("A", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
417+
Schema.Field.of("B", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
418+
Schema.Field.of("C", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
419+
Schema.Field.of("D", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
420+
Schema.Field.of("E", Schema.nullableOf(Schema.of(Schema.Type.INT))),
421+
Schema.Field.of("F", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
422+
Schema.Field.of("G", Schema.nullableOf(Schema.of(Schema.Type.LONG)))
388423
);
389424
Schema expectedSchema = Schema.recordOf(
390425
"expectedSchema",

0 commit comments

Comments
 (0)