diff --git a/transportable-udfs-annotation-processor/src/main/java/com/linkedin/transport/processor/TransportProcessor.java b/transportable-udfs-annotation-processor/src/main/java/com/linkedin/transport/processor/TransportProcessor.java index 12e25002..30a110a8 100644 --- a/transportable-udfs-annotation-processor/src/main/java/com/linkedin/transport/processor/TransportProcessor.java +++ b/transportable-udfs-annotation-processor/src/main/java/com/linkedin/transport/processor/TransportProcessor.java @@ -130,10 +130,19 @@ private void processUDFClass(TypeElement udfClassElement) { udfClassElement ); } else { - String topLevelStdUdfClassName = - elementsOverridingTopLevelStdUDFMethods.iterator().next().getQualifiedName().toString(); + TypeElement topLevelStdUdfTypeElement = elementsOverridingTopLevelStdUDFMethods.iterator().next(); + String topLevelStdUdfClassName = topLevelStdUdfTypeElement.getQualifiedName().toString(); debug(String.format("TopLevelStdUDF class found: %s", topLevelStdUdfClassName)); + String udfClassName = udfClassElement.getQualifiedName().toString(); _transportUdfMetadata.addUDF(topLevelStdUdfClassName, udfClassElement.getQualifiedName().toString()); + _transportUdfMetadata.setClassNumberOfTypeParameters( + topLevelStdUdfClassName, + topLevelStdUdfTypeElement.getTypeParameters().size() + ); + _transportUdfMetadata.setClassNumberOfTypeParameters( + udfClassName, + udfClassElement.getTypeParameters().size() + ); } } diff --git a/transportable-udfs-annotation-processor/src/test/java/com/linkedin/transport/processor/TransportProcessorTest.java b/transportable-udfs-annotation-processor/src/test/java/com/linkedin/transport/processor/TransportProcessorTest.java index 0322f306..dd3c25c5 100644 --- a/transportable-udfs-annotation-processor/src/test/java/com/linkedin/transport/processor/TransportProcessorTest.java +++ b/transportable-udfs-annotation-processor/src/test/java/com/linkedin/transport/processor/TransportProcessorTest.java @@ -81,7 +81,7 @@ public 
void shouldNotContainMultipleOverridingsOfTopLevelStdUDFMethods1() throws .withErrorCount(1) .withErrorContaining(Constants.MORE_THAN_ONE_TYPE_OVERRIDING_ERROR) .in(forResource("udfs/UDFWithMultipleInterfaces1.java")) - .onLine(14) + .onLine(13) .atColumn(8); } @@ -96,7 +96,7 @@ public void shouldNotContainMultipleOverridingsOfTopLevelStdUDFMethods2() throws .withErrorCount(1) .withErrorContaining(Constants.MORE_THAN_ONE_TYPE_OVERRIDING_ERROR) .in(forResource("udfs/UDFWithMultipleInterfaces2.java")) - .onLine(13) + .onLine(12) .atColumn(8); } @@ -110,7 +110,7 @@ public void udfShouldNotOverrideInterfaceMethods() throws IOException { .withErrorCount(1) .withErrorContaining(Constants.MORE_THAN_ONE_TYPE_OVERRIDING_ERROR) .in(forResource("udfs/UDFOverridingInterfaceMethod.java")) - .onLine(14) + .onLine(13) .atColumn(8); } @@ -123,7 +123,7 @@ public void udfShouldImplementTopLevelStdUDF() throws IOException { .withErrorCount(1) .withErrorContaining(Constants.INTERFACE_NOT_IMPLEMENTED_ERROR) .in(forResource("udfs/UDFNotImplementingTopLevelStdUDF.java")) - .onLine(14) + .onLine(13) .atColumn(8); } diff --git a/transportable-udfs-annotation-processor/src/test/resources/outputs/empty.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/empty.json index 7836d3b6..4e9cfcb3 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/empty.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/empty.json @@ -1,3 +1,4 @@ { - "udfs": [] + "udfs": {}, + "classToNumberOfTypeParameters": {} } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/outputs/overloadedUDF.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/overloadedUDF.json index 9f6a2450..4f0526f0 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/overloadedUDF.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/overloadedUDF.json @@ 
-1,11 +1,13 @@ { - "udfs": [ - { - "topLevelClass": "udfs.OverloadedUDF1", - "stdUDFImplementations": [ - "udfs.OverloadedUDFInt", - "udfs.OverloadedUDFString" - ] - } - ] + "udfs": { + "udfs.OverloadedUDF1": [ + "udfs.OverloadedUDFInt", + "udfs.OverloadedUDFString" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.OverloadedUDFString": 0, + "udfs.OverloadedUDF1": 0, + "udfs.OverloadedUDFInt": 0 + } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/outputs/simpleUDF.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/simpleUDF.json index 9323ddbd..34c7cee2 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/simpleUDF.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/simpleUDF.json @@ -1,10 +1,10 @@ { - "udfs": [ - { - "topLevelClass": "udfs.SimpleUDF", - "stdUDFImplementations": [ - "udfs.SimpleUDF" - ] - } - ] + "udfs": { + "udfs.SimpleUDF": [ + "udfs.SimpleUDF" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.SimpleUDF": 0 + } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDF.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDF.json index ab58d4d8..5b72a274 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDF.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDF.json @@ -1,10 +1,10 @@ { - "udfs": [ - { - "topLevelClass": "udfs.UDFExtendingAbstractUDF", - "stdUDFImplementations": [ - "udfs.UDFExtendingAbstractUDF" - ] - } - ] + "udfs": { + "udfs.UDFExtendingAbstractUDF": [ + "udfs.UDFExtendingAbstractUDF" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.UDFExtendingAbstractUDF": 0 + } } \ No newline at end of file diff --git 
a/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDFImplementingInterface.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDFImplementingInterface.json index d2551a77..b75531e1 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDFImplementingInterface.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDFImplementingInterface.json @@ -1,10 +1,11 @@ { - "udfs": [ - { - "topLevelClass": "udfs.AbstractUDFImplementingInterface", - "stdUDFImplementations": [ - "udfs.UDFExtendingAbstractUDFImplementingInterface" - ] - } - ] + "udfs": { + "udfs.AbstractUDFImplementingInterface": [ + "udfs.UDFExtendingAbstractUDFImplementingInterface" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.UDFExtendingAbstractUDFImplementingInterface": 0, + "udfs.AbstractUDFImplementingInterface": 0 + } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDF.java index 4a482115..06536aa2 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDF.java @@ -5,11 +5,10 @@ */ package udfs; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.TopLevelStdUDF; -public abstract class AbstractUDF extends StdUDF0 implements TopLevelStdUDF { +public abstract class AbstractUDF extends StdUDF0 implements TopLevelStdUDF { } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDFImplementingInterface.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDFImplementingInterface.java index 
4078d7bc..7d85fb36 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDFImplementingInterface.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDFImplementingInterface.java @@ -5,12 +5,11 @@ */ package udfs; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.TopLevelStdUDF; -public abstract class AbstractUDFImplementingInterface extends StdUDF0 implements TopLevelStdUDF { +public abstract class AbstractUDFImplementingInterface extends StdUDF0 implements TopLevelStdUDF { @Override public String getFunctionName() { diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/OuterClassForInnerUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/OuterClassForInnerUDF.java index d5d2551d..a84ee746 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/OuterClassForInnerUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/OuterClassForInnerUDF.java @@ -6,14 +6,13 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; public class OuterClassForInnerUDF { - public class InnerUDF extends StdUDF0 implements TopLevelStdUDF { + public class InnerUDF extends StdUDF0 implements TopLevelStdUDF { @Override public String getFunctionName() { @@ -36,7 +35,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFInt.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFInt.java index 3f130d9d..292c8606 100644 --- 
a/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFInt.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFInt.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdInteger; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class OverloadedUDFInt extends StdUDF0 implements OverloadedUDF1 { +public class OverloadedUDFInt extends StdUDF0 implements OverloadedUDF1 { @Override public List getInputParameterSignatures() { @@ -24,7 +23,7 @@ public String getOutputParameterSignature() { } @Override - public StdInteger eval() { + public Integer eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFString.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFString.java index 9782d683..d0855f55 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFString.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFString.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class OverloadedUDFString extends StdUDF0 implements OverloadedUDF1 { +public class OverloadedUDFString extends StdUDF0 implements OverloadedUDF1 { @Override public List getInputParameterSignatures() { @@ -24,7 +23,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/SimpleUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/SimpleUDF.java index 15e749ec..46231c63 100644 --- 
a/transportable-udfs-annotation-processor/src/test/resources/udfs/SimpleUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/SimpleUDF.java @@ -6,13 +6,12 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class SimpleUDF extends StdUDF0 implements TopLevelStdUDF { +public class SimpleUDF extends StdUDF0 implements TopLevelStdUDF { @Override public String getFunctionName() { @@ -35,7 +34,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDF.java index 99fb068c..564fcd7e 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDF.java @@ -6,7 +6,6 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; @@ -34,7 +33,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDFImplementingInterface.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDFImplementingInterface.java index 2fc6d5ce..db8e3bc7 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDFImplementingInterface.java +++ 
b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDFImplementingInterface.java @@ -6,7 +6,6 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import java.util.List; @@ -23,7 +22,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFNotImplementingTopLevelStdUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFNotImplementingTopLevelStdUDF.java index 43e8ffa9..862403e4 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFNotImplementingTopLevelStdUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFNotImplementingTopLevelStdUDF.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class UDFNotImplementingTopLevelStdUDF extends StdUDF0 { +public class UDFNotImplementingTopLevelStdUDF extends StdUDF0 { @Override public List getInputParameterSignatures() { @@ -24,7 +23,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFOverridingInterfaceMethod.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFOverridingInterfaceMethod.java index 346ff553..2d97547e 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFOverridingInterfaceMethod.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFOverridingInterfaceMethod.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; 
-import com.linkedin.transport.api.data.StdBoolean; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class UDFOverridingInterfaceMethod extends StdUDF0 implements OverloadedUDF1 { +public class UDFOverridingInterfaceMethod extends StdUDF0 implements OverloadedUDF1 { @Override public String getFunctionName() { @@ -29,7 +28,7 @@ public String getOutputParameterSignature() { } @Override - public StdBoolean eval() { + public Boolean eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces1.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces1.java index 54e57c16..83a38c4d 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces1.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces1.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdBoolean; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class UDFWithMultipleInterfaces1 extends StdUDF0 implements OverloadedUDF1, OverloadedUDF2 { +public class UDFWithMultipleInterfaces1 extends StdUDF0 implements OverloadedUDF1, OverloadedUDF2 { @Override public String getFunctionName() { @@ -34,7 +33,7 @@ public String getOutputParameterSignature() { } @Override - public StdBoolean eval() { + public Boolean eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces2.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces2.java index e8e62bb1..f2fbe270 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces2.java +++ 
b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces2.java @@ -6,7 +6,6 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import java.util.List; @@ -33,7 +32,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java index a7f4fc5a..76b9e239 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java @@ -5,29 +5,20 @@ */ package com.linkedin.transport.api; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdBoolean; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdDouble; -import com.linkedin.transport.api.data.StdFloat; -import com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdLong; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdArrayType; import com.linkedin.transport.api.types.StdMapType; -import com.linkedin.transport.api.types.StdStructType; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF; import java.io.Serializable; -import java.nio.ByteBuffer; import java.util.List; /** - * {@link StdFactory} is used to create {@link StdData} and {@link StdType} objects inside 
Standard UDFs. + * {@link StdFactory} is used to create container types (e.g., {@link ArrayData}, {@link MapData}, {@link RowData}) + * and {@link StdType} objects inside Standard UDFs. * * Specific APIs of {@link StdFactory} are implemented by each target platform (e.g., Spark, Trino, Hive) individually. * A {@link StdFactory} object is available inside Standard UDFs using {@link StdUDF#getStdFactory()}. @@ -36,135 +27,79 @@ public interface StdFactory extends Serializable { /** - * Creates a {@link StdInteger} representing a given integer value. - * - * @param value the input integer value - * @return {@link StdInteger} with the given integer value - */ - StdInteger createInteger(int value); - - /** - * Creates a {@link StdLong} representing a given long value. - * - * @param value the input long value - * @return {@link StdLong} with the given long value - */ - StdLong createLong(long value); - - /** - * Creates a {@link StdBoolean} representing a given boolean value. - * - * @param value the input boolean value - * @return {@link StdBoolean} with the given boolean value - */ - StdBoolean createBoolean(boolean value); - - /** - * Creates a {@link StdString} representing a given {@link String} value. - * - * @param value the input {@link String} value - * @return {@link StdString} with the given {@link String} value - */ - StdString createString(String value); - - /** - * Creates a {@link StdFloat} representing a given float value. - * - * @param value the input float value - * @return {@link StdFloat} with the given float value - */ - StdFloat createFloat(float value); - - /** - * Creates a {@link StdDouble} representing a given double value. - * - * @param value the input double value - * @return {@link StdDouble} with the given double value - */ - StdDouble createDouble(double value); - - /** - * Creates a {@link StdBinary} representing a given {@link ByteBuffer} value. 
- * - * @param value the input {@link ByteBuffer} value - * @return {@link StdBinary} with the given {@link ByteBuffer} value - */ - StdBinary createBinary(ByteBuffer value); - - /** - * Creates an empty {@link StdArray} whose type is given by the given {@link StdType}. + * Creates an empty {@link ArrayData} whose type is given by the given {@link StdType}. * * It is expected that the top-level {@link StdType} is a {@link StdArrayType}. * * @param stdType type of the array to be created * @param expectedSize expected number of entries in the array - * @return an empty {@link StdArray} + * @return an empty {@link ArrayData} */ - StdArray createArray(StdType stdType, int expectedSize); + ArrayData createArray(StdType stdType, int expectedSize); /** - * Creates an empty {@link StdArray} whose type is given by the given {@link StdType}. + * Creates an empty {@link ArrayData} whose type is given by the given {@link StdType}. * * It is expected that the top-level {@link StdType} is a {@link StdArrayType}. * * @param stdType type of the array to be created - * @return an empty {@link StdArray} + * @return an empty {@link ArrayData} */ - StdArray createArray(StdType stdType); + ArrayData createArray(StdType stdType); /** - * Creates an empty {@link StdMap} whose type is given by the given {@link StdType}. + * Creates an empty {@link MapData} whose type is given by the given {@link StdType}. * * It is expected that the top-level {@link StdType} is a {@link StdMapType}. * * @param stdType type of the map to be created - * @return an empty {@link StdMap} + * @return an empty {@link MapData} */ - StdMap createMap(StdType stdType); + MapData createMap(StdType stdType); /** - * Creates a {@link StdStruct} with the given field names and types. + * Creates a {@link RowData} with the given field names and types. 
* * @param fieldNames names of the struct fields * @param fieldTypes types of the struct fields - * @return a {@link StdStruct} with all fields initialized to null + * @return a {@link RowData} with all fields initialized to null */ - StdStruct createStruct(List fieldNames, List fieldTypes); + RowData createStruct(List fieldNames, List fieldTypes); /** - * Creates a {@link StdStruct} with the given field types. Field names will be field0, field1, field2... + * Creates a {@link RowData} with the given field types. Field names will be field0, field1, field2... * * @param fieldTypes types of the struct fields - * @return a {@link StdStruct} with all fields initialized to null + * @return a {@link RowData} with all fields initialized to null */ - StdStruct createStruct(List fieldTypes); + RowData createStruct(List fieldTypes); /** - * Creates a {@link StdStruct} whose type is given by the given {@link StdType}. + * Creates a {@link RowData} whose type is given by the given {@link StdType}. * - * It is expected that the top-level {@link StdType} is a {@link StdStructType}. + * It is expected that the top-level {@link StdType} is a {@link com.linkedin.transport.api.types.RowType}. * * @param stdType type of the struct to be created - * @return a {@link StdStruct} with all fields initialized to null + * @return a {@link RowData} with all fields initialized to null */ - StdStruct createStruct(StdType stdType); + RowData createStruct(StdType stdType); /** * Creates a {@link StdType} representing the given type signature. * * The following are considered valid type signatures: *
    - *
  • {@code "varchar"} - Represents SQL varchar type. Corresponding standard type is {@link StdString}
  • - *
  • {@code "integer"} - Represents SQL int type. Corresponding standard type is {@link StdInteger}
  • - *
  • {@code "bigint"} - Represents SQL bigint/long type. Corresponding standard type is {@link StdLong}
  • - *
  • {@code "boolean"} - Represents SQL boolean type. Corresponding standard type is {@link StdBoolean}
  • + *
  • {@code "varchar"} - Represents SQL varchar type. Corresponding Transport type is {@link String}
  • + *
  • {@code "integer"} - Represents SQL int type. Corresponding Transport type is {@link Integer}
  • + *
  • {@code "bigint"} - Represents SQL bigint/long type. Corresponding Transport type is {@link Long}
  • + *
  • {@code "boolean"} - Represents SQL boolean type. Corresponding Transport type is {@link Boolean}
  • *
  • {@code "array(T)"} - Represents SQL array type, where {@code T} is type signature of array element. - * Corresponding standard type is {@link StdArray}
  • + * Corresponding Transport type is {@link ArrayData} *
  • {@code "map(K,V)"} - Represents SQL map type, where {@code K} and {@code V} are type signatures of the map - * keys and values respectively. array element. Corresponding standard type is {@link StdMap}
  • + * keys and values respectively. Corresponding Transport type is {@link MapData} *
  • {@code "row(f0 T0, f1 T1,... fn Tn)"} - Represents SQL struct type, where {@code f0}...{@code fn} are field * names and {@code T0}...{@code Tn} are type signatures for the fields. Field names are optional; if not - * specified they default to {@code field0}...{@code fieldn}. Corresponding standard type is {@link StdStruct}
  • + * specified they default to {@code field0}...{@code fieldn}. Corresponding Transport type is {@link RowData} *
* * Generic type parameters can also be used as part of the type signatures; e.g., The type signature {@code "map(K,V)"} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdArray.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/ArrayData.java similarity index 76% rename from transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdArray.java rename to transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/ArrayData.java index ac698ae8..65a6b9a6 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdArray.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/ArrayData.java @@ -5,8 +5,8 @@ */ package com.linkedin.transport.api.data; -/** A Standard UDF data type for representing arrays. */ -public interface StdArray extends StdData, Iterable { +/** A Transport UDF data type for representing arrays. */ +public interface ArrayData extends Iterable { /** Returns the number of elements in the array. */ int size(); @@ -16,12 +16,12 @@ public interface StdArray extends StdData, Iterable { * * @param idx the index of the element to be retrieved */ - StdData get(int idx); + E get(int idx); /** * Adds an element to the end of the array. 
* * @param e the element to append to the array */ - void add(StdData e); + void add(E e); } diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdMap.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/MapData.java similarity index 78% rename from transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdMap.java rename to transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/MapData.java index 8e67500e..39bd6965 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdMap.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/MapData.java @@ -9,8 +9,8 @@ import java.util.Set; -/** A Standard UDF data type for representing maps. */ -public interface StdMap extends StdData { +/** A Transport UDF data type for representing maps. */ +public interface MapData { /** Returns the number of key-value pairs in the map. */ int size(); @@ -20,7 +20,7 @@ public interface StdMap extends StdData { * * @param key the key whose value is to be returned */ - StdData get(StdData key); + V get(K key); /** * Adds the given value to the map against the given key. @@ -28,18 +28,18 @@ public interface StdMap extends StdData { * @param key the key to which the value is to be associated * @param value the value to be associated with the key */ - void put(StdData key, StdData value); + void put(K key, V value); /** Returns a {@link Set} of all the keys in the map. */ - Set keySet(); + Set keySet(); /** Returns a {@link Collection} of all the values in the map. */ - Collection values(); + Collection values(); /** * Returns true if the map contains the given key, false otherwise. 
* * @param key the key to be checked */ - boolean containsKey(StdData key); + boolean containsKey(K key); } diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/PlatformData.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/PlatformData.java index 75df0518..41fc7bae 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/PlatformData.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/PlatformData.java @@ -5,7 +5,7 @@ */ package com.linkedin.transport.api.data; -/** An interface for all platform-specific implementations of {@link StdData}. */ +/** An interface to handle platform-specific container types. */ public interface PlatformData { /** Returns the underlying platform-specific object holding the data. */ diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdStruct.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/RowData.java similarity index 50% rename from transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdStruct.java rename to transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/RowData.java index 14ccff80..2d8f1ce0 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdStruct.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/RowData.java @@ -8,39 +8,39 @@ import java.util.List; -/** A Standard UDF data type for representing structs. */ -public interface StdStruct extends StdData { +/** A Transport UDF data type for representing SQL ROW/STRUCT data type. */ +public interface RowData { /** - * Returns the value of the field at the given position in the struct. + * Returns the value of the field at the given position in the row. 
* - * @param index the position of the field in the struct + * @param index the position of the field in the row */ - StdData getField(int index); + Object getField(int index); /** - * Returns the value of the given field from the struct. + * Returns the value of the given field from the row. * * @param name the name of the field */ - StdData getField(String name); + Object getField(String name); /** - * Sets the value of the field at the given position in the struct. + * Sets the value of the field at the given position in the row. * - * @param index the position of the field in the struct + * @param index the position of the field in the row * @param value the value to be set */ - void setField(int index, StdData value); + void setField(int index, Object value); /** - * Sets the value of the given field in the struct. + * Sets the value of the given field in the row. * * @param name the name of the field * @param value the value to be set */ - void setField(String name, StdData value); + void setField(String name, Object value); /** Returns a {@link List} of all fields in the struct. */ - List fields(); + List fields(); } diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java deleted file mode 100644 index d1fc4acb..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java +++ /dev/null @@ -1,15 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -import java.nio.ByteBuffer; - -/** A Standard UDF data type for representing binary objects. */ -public interface StdBinary extends StdData { - - /** Returns the underlying {@link ByteBuffer} value. 
*/ - ByteBuffer get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBoolean.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBoolean.java deleted file mode 100644 index ff230bc1..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBoolean.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing booleans. */ -public interface StdBoolean extends StdData { - - /** Returns the underlying boolean value. */ - boolean get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdData.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdData.java deleted file mode 100644 index 77b3d1d7..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdData.java +++ /dev/null @@ -1,19 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -import com.linkedin.transport.api.StdFactory; - - -/** - * An interface for all data types in Standard UDFs. - * - * {@link StdData} is the main interface through which StdUDFs receive input data and return output data. All Standard - * UDF data types (e.g., {@link StdInteger}, {@link StdArray}, {@link StdMap}) must extend {@link StdData}. Methods - * inside {@link StdFactory} can be used to create {@link StdData} objects. 
- */ -public interface StdData { -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java deleted file mode 100644 index a96fcc0e..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing doubles. */ -public interface StdDouble extends StdData { - - /** Returns the underlying double value. */ - double get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java deleted file mode 100644 index da76dd28..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing floats. */ -public interface StdFloat extends StdData { - - /** Returns the underlying float value. */ - float get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdInteger.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdInteger.java deleted file mode 100644 index c74a92dd..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdInteger.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. 
- * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing integers. */ -public interface StdInteger extends StdData { - - /** Returns the underlying int value. */ - int get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdLong.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdLong.java deleted file mode 100644 index 84f322f7..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdLong.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing longs. */ -public interface StdLong extends StdData { - - /** Returns the underlying long value. */ - long get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdString.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdString.java deleted file mode 100644 index 7ccd8385..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdString.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing strings. */ -public interface StdString extends StdData { - - /** Returns the underlying {@link String} value. 
*/ - String get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdTimestamp.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdTimestamp.java deleted file mode 100644 index 18ff9bc1..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdTimestamp.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing timestamps. */ -public interface StdTimestamp extends StdData { - - /** Returns the number of milliseconds elapsed from epoch for the {@link StdTimestamp}. */ - long toEpoch(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdStructType.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/RowType.java similarity index 89% rename from transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdStructType.java rename to transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/RowType.java index 521ec2d7..59938208 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdStructType.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/RowType.java @@ -9,7 +9,7 @@ /** A {@link StdType} representing a struct type. */ -public interface StdStructType extends StdType { +public interface RowType extends StdType { /** Returns a {@link List} of the types of all the struct fields. 
*/ List fieldTypes(); diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF.java index a4e29220..f90b91ca 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF.java @@ -6,8 +6,6 @@ package com.linkedin.transport.api.udf; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.types.StdType; import java.util.List; @@ -19,8 +17,7 @@ * abstract class for UDFs expecting {@code i} arguments. Similar to lambda expressions, StdUDF(i) abstract classes are * type-parameterized by the input types and output type of the eval function. Each class is type-parameterized by * {@code (i+1)} type parameters; {@code i} type parameters for the UDF input types, and one type parameter for the - * output type. All types (both input and output types) must extend the {@link StdData} - * interface. + * output type. */ public abstract class StdUDF { private StdFactory _stdFactory; @@ -40,7 +37,7 @@ public abstract class StdUDF { * of contained UDF. 
* * @param stdFactory a {@link StdFactory} object which can be used to create - * {@link StdData} and {@link StdType} objects + * data and type objects */ public void init(StdFactory stdFactory) { _stdFactory = stdFactory; @@ -85,8 +82,8 @@ public final boolean[] getAndCheckNullableArguments() { protected abstract int numberOfArguments(); /** - * Returns a {@link StdFactory} object which can be used to create {@link StdData} and - * {@link StdType} objects + * Returns a {@link StdFactory} object which can be used to create data and + * type objects */ public StdFactory getStdFactory() { return _stdFactory; diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF0.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF0.java index b62fe95f..d3558fc3 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF0.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF0.java @@ -5,15 +5,13 @@ */ package com.linkedin.transport.api.udf; -import com.linkedin.transport.api.data.StdData; - /** * A Standard UDF with zero input arguments. * * @param the type of the return value of the {@link StdUDF} */ -public abstract class StdUDF0 extends StdUDF { +public abstract class StdUDF0 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF1.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF1.java index 28d0ad71..18e8e769 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF1.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF1.java @@ -5,8 +5,6 @@ */ package com.linkedin.transport.api.udf; -import com.linkedin.transport.api.data.StdData; - /** * A Standard UDF with one input argument. 
@@ -17,7 +15,7 @@ // Suppressing class parameter type parameter name and arg naming style checks since this naming convention is more // suitable to Standard UDFs, and the code is more readable this way. @SuppressWarnings({"checkstyle:classtypeparametername", "checkstyle:regexpsinglelinejava"}) -public abstract class StdUDF1 extends StdUDF { +public abstract class StdUDF1 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -38,7 +36,7 @@ public abstract class StdUDF1 extends Std * hence obtaining the most recent version of a file. * Example: 'hdfs:///data/derived/dwh/prop/testMemberId/#LATEST/testMemberId.txt' * - * The arguments passed to {@link #eval(StdData)} are passed to this method as well to allow users to construct + * The arguments passed to {@link #eval(Object)} are passed to this method as well to allow users to construct * required file paths from arguments passed to the UDF. Since this method is called before any rows are processed, * only constant UDF arguments should be used to construct the file paths. Values of non-constant arguments are not * deterministic, and are null for most platforms. (Constant arguments are arguments whose literal values are given diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF2.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF2.java index 3e020ae1..3eb293ba 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF2.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF2.java @@ -5,8 +5,6 @@ */ package com.linkedin.transport.api.udf; -import com.linkedin.transport.api.data.StdData; - /** * A Standard UDF with three input arguments. @@ -18,7 +16,7 @@ // Suppressing class parameter type parameter name and arg naming style checks since this naming convention is more // suitable to Standard UDFs, and the code is more readable this way. 
@SuppressWarnings({"checkstyle:classtypeparametername", "checkstyle:regexpsinglelinejava"}) -public abstract class StdUDF2 extends StdUDF { +public abstract class StdUDF2 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -40,7 +38,7 @@ public abstract class StdUDF2 +public abstract class StdUDF3 extends StdUDF { /** @@ -43,7 +41,7 @@ public abstract class StdUDF3 +public abstract class StdUDF4 extends StdUDF { /** @@ -45,7 +43,7 @@ public abstract class StdUDF4 extends StdUDF { +public abstract class StdUDF5 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -47,7 +45,7 @@ public abstract class StdUDF5 extends StdUDF { +public abstract class StdUDF6 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -49,7 +47,7 @@ public abstract class StdUDF6 extends StdUDF { +public abstract class StdUDF7 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -51,7 +49,7 @@ public abstract class StdUDF7 extends StdUDF { +public abstract class StdUDF8 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. 
@@ -53,7 +51,7 @@ public abstract class StdUDF8 boundVariables) { } @Override - public StdInteger createInteger(int value) { - return new AvroInteger(value); + public ArrayData createArray(StdType stdType, int size) { + return new AvroArrayData((Schema) stdType.underlyingType(), size); } @Override - public StdLong createLong(long value) { - return new AvroLong(value); - } - - @Override - public StdBoolean createBoolean(boolean value) { - return new AvroBoolean(value); - } - - @Override - public StdString createString(String value) { - return new AvroString(new Utf8(value)); - } - - @Override - public StdFloat createFloat(float value) { - return new AvroFloat(value); - } - - @Override - public StdDouble createDouble(double value) { - return new AvroDouble(value); - } - - @Override - public StdBinary createBinary(ByteBuffer value) { - return new AvroBinary(value); - } - - @Override - public StdArray createArray(StdType stdType, int size) { - return new AvroArray((Schema) stdType.underlyingType(), size); - } - - @Override - public StdArray createArray(StdType stdType) { + public ArrayData createArray(StdType stdType) { return createArray(stdType, 0); } @Override - public StdMap createMap(StdType stdType) { - return new AvroMap((Schema) stdType.underlyingType()); + public MapData createMap(StdType stdType) { + return new AvroMapData((Schema) stdType.underlyingType()); } @Override - public StdStruct createStruct(List fieldNames, List fieldTypes) { + public RowData createStruct(List fieldNames, List fieldTypes) { if (fieldNames.size() != fieldTypes.size()) { throw new RuntimeException( "Field names and types are of different lengths: " + "Field names length is " + fieldNames.size() + ". 
" @@ -112,18 +61,18 @@ public StdStruct createStruct(List fieldNames, List fieldTypes) for (int i = 0; i < fieldTypes.size(); i++) { fields.add(new Field(fieldNames.get(i), (Schema) fieldTypes.get(i).underlyingType(), null, null)); } - return new AvroStruct(Schema.createRecord(fields)); + return new AvroRowData(Schema.createRecord(fields)); } @Override - public StdStruct createStruct(List fieldTypes) { + public RowData createStruct(List fieldTypes) { return createStruct(IntStream.range(0, fieldTypes.size()).mapToObj(i -> "field" + i).collect(Collectors.toList()), fieldTypes); } @Override - public StdStruct createStruct(StdType stdType) { - return new AvroStruct((Schema) stdType.underlyingType()); + public RowData createStruct(StdType stdType) { + return new AvroRowData((Schema) stdType.underlyingType()); } @Override diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java index 1f7657e5..d8149724 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java @@ -5,18 +5,11 @@ */ package com.linkedin.transport.avro; -import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.PlatformData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.avro.data.AvroArray; -import com.linkedin.transport.avro.data.AvroBinary; -import com.linkedin.transport.avro.data.AvroBoolean; -import com.linkedin.transport.avro.data.AvroDouble; -import com.linkedin.transport.avro.data.AvroFloat; -import com.linkedin.transport.avro.data.AvroInteger; -import com.linkedin.transport.avro.data.AvroLong; -import com.linkedin.transport.avro.data.AvroMap; -import com.linkedin.transport.avro.data.AvroString; -import com.linkedin.transport.avro.data.AvroStruct; +import 
com.linkedin.transport.avro.data.AvroArrayData; +import com.linkedin.transport.avro.data.AvroMapData; +import com.linkedin.transport.avro.data.AvroRowData; import com.linkedin.transport.avro.types.AvroArrayType; import com.linkedin.transport.avro.types.AvroBinaryType; import com.linkedin.transport.avro.types.AvroBooleanType; @@ -26,13 +19,12 @@ import com.linkedin.transport.avro.types.AvroLongType; import com.linkedin.transport.avro.types.AvroMapType; import com.linkedin.transport.avro.types.AvroStringType; -import com.linkedin.transport.avro.types.AvroStructType; +import com.linkedin.transport.avro.types.AvroRowType; import java.nio.ByteBuffer; import java.util.List; import java.util.Map; import org.apache.avro.Schema; import org.apache.avro.generic.GenericArray; -import org.apache.avro.generic.GenericEnumSymbol; import org.apache.avro.generic.GenericRecord; import org.apache.avro.util.Utf8; @@ -42,50 +34,28 @@ public class AvroWrapper { private AvroWrapper() { } - public static StdData createStdData(Object avroData, Schema avroSchema) { + public static Object createStdData(Object avroData, Schema avroSchema) { switch (avroSchema.getType()) { case INT: - return new AvroInteger((Integer) avroData); case LONG: - return new AvroLong((Long) avroData); case BOOLEAN: - return new AvroBoolean((Boolean) avroData); - case ENUM: { - if (avroData == null) { - return new AvroString(null); - } - - if (avroData instanceof String) { - return new AvroString(new Utf8((String) avroData)); - } else if (avroData instanceof GenericEnumSymbol) { - return new AvroString(new Utf8(((GenericEnumSymbol) avroData).toString())); - } - throw new IllegalArgumentException("Unsupported type for Avro enum: " + avroData.getClass()); - } - case STRING: { - if (avroData == null) { - return new AvroString(null); - } - - if (avroData instanceof Utf8) { - return new AvroString((Utf8) avroData); - } else if (avroData instanceof String) { - return new AvroString(new Utf8((String) avroData)); - } - throw 
new IllegalArgumentException("Unsupported type for Avro string: " + avroData.getClass()); - } case FLOAT: - return new AvroFloat((Float) avroData); case DOUBLE: - return new AvroDouble((Double) avroData); case BYTES: - return new AvroBinary((ByteBuffer) avroData); + return avroData; + case STRING: + case ENUM: + if (avroData == null) { + return null; + } else { + return avroData.toString(); + } case ARRAY: - return new AvroArray((GenericArray) avroData, avroSchema); + return new AvroArrayData((GenericArray) avroData, avroSchema); case MAP: - return new AvroMap((Map) avroData, avroSchema); + return new AvroMapData((Map) avroData, avroSchema); case RECORD: - return new AvroStruct((GenericRecord) avroData, avroSchema); + return new AvroRowData((GenericRecord) avroData, avroSchema); case UNION: { Schema nonNullableType = getNonNullComponent(avroSchema); if (avroData == null) { @@ -100,6 +70,18 @@ public static StdData createStdData(Object avroData, Schema avroSchema) { } } + public static Object getPlatformData(Object transportData) { + if (transportData instanceof Integer || transportData instanceof Long || transportData instanceof Double + || transportData instanceof Boolean || transportData instanceof ByteBuffer) { + return transportData; + } else if (transportData instanceof String) { + return transportData == null ? null : new Utf8((String) transportData); + } else { + return transportData == null ? null : ((PlatformData) transportData).getUnderlyingData(); + } + } + + /** * Returns a non null component of a simple union schema. The supported union schema must have * only two fields where one of them is null type, the other is returned. 
@@ -139,7 +121,7 @@ public static StdType createStdType(Schema avroSchema) { case MAP: return new AvroMapType(avroSchema); case RECORD: - return new AvroStructType(avroSchema); + return new AvroRowType(avroSchema); case UNION: { Schema nonNullableType = getNonNullComponent(avroSchema); return createStdType(nonNullableType); diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/StdUdfWrapper.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/StdUdfWrapper.java index a1ea3e0d..a75a8e6a 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/StdUdfWrapper.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/StdUdfWrapper.java @@ -7,7 +7,6 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.StdUDF1; @@ -36,7 +35,7 @@ public abstract class StdUdfWrapper { protected boolean _requiredFilesProcessed; protected StdFactory _stdFactory; private boolean[] _nullableArguments; - private StdData[] _args; + private Object[] _args; /** * Given input schemas, this method matches them to the expected type signatures, and finds bindings to the @@ -68,12 +67,27 @@ protected boolean containsNullValuedNonNullableArgument(Object[] arguments) { return false; } - protected StdData wrap(Object avroObject, StdData stdData) { - if (avroObject != null) { - ((PlatformData) stdData).setUnderlyingData(avroObject); - return stdData; - } else { - return null; + protected Object wrap(Object avroObject, Schema inputSchema, Object stdData) { + switch (inputSchema.getType()) { + case INT: + case LONG: + case BOOLEAN: + return avroObject; + case STRING: + return avroObject == null ? 
null : avroObject.toString(); + case ARRAY: + case MAP: + case RECORD: + if (avroObject != null) { + ((PlatformData) stdData).setUnderlyingData(avroObject); + return stdData; + } else { + return null; + } + case NULL: + return null; + default: + throw new RuntimeException("Unrecognized Avro Schema: " + inputSchema.getClass()); } } @@ -82,22 +96,24 @@ protected StdData wrap(Object avroObject, StdData stdData) { protected abstract Class getTopLevelUdfClass(); protected void createStdData() { - _args = new StdData[_inputSchemas.length]; + _args = new Object[_inputSchemas.length]; for (int i = 0; i < _inputSchemas.length; i++) { _args[i] = AvroWrapper.createStdData(null, _inputSchemas[i]); } } - private StdData[] wrapArguments(Object[] arguments) { - return IntStream.range(0, _args.length).mapToObj(i -> wrap(arguments[i], _args[i])).toArray(StdData[]::new); + private Object[] wrapArguments(Object[] arguments) { + return IntStream.range(0, _args.length).mapToObj( + i -> wrap(arguments[i], _inputSchemas[i], _args[i]) + ).toArray(Object[]::new); } public Object evaluate(Object[] arguments) { if (containsNullValuedNonNullableArgument(arguments)) { return null; } - StdData[] args = wrapArguments(arguments); - StdData result; + Object[] args = wrapArguments(arguments); + Object result; switch (args.length) { case 0: result = ((StdUDF0) _stdUdf).eval(); @@ -129,6 +145,6 @@ public Object evaluate(Object[] arguments) { default: throw new UnsupportedOperationException("eval not yet supported for StdUDF" + args.length); } - return result == null ? 
null : ((PlatformData) result).getUnderlyingData(); + return AvroWrapper.getPlatformData(result); } } diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArray.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArrayData.java similarity index 65% rename from transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArray.java rename to transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArrayData.java index 1557ed6c..3cb2e6f6 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArray.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArrayData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.avro.data; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.avro.AvroWrapper; import java.util.Iterator; import org.apache.avro.Schema; @@ -15,16 +14,16 @@ import org.apache.avro.generic.GenericData; -public class AvroArray implements StdArray, PlatformData { +public class AvroArrayData implements ArrayData, PlatformData { private final Schema _elementSchema; private GenericArray _genericArray; - public AvroArray(GenericArray genericArray, Schema arraySchema) { + public AvroArrayData(GenericArray genericArray, Schema arraySchema) { _genericArray = genericArray; _elementSchema = arraySchema.getElementType(); } - public AvroArray(Schema arraySchema, int size) { + public AvroArrayData(Schema arraySchema, int size) { _elementSchema = arraySchema.getElementType(); _genericArray = new GenericData.Array(size, arraySchema); } @@ -35,18 +34,18 @@ public int size() { } @Override - public StdData get(int idx) { - return AvroWrapper.createStdData(_genericArray.get(idx), _elementSchema); + public E get(int idx) { + return (E) 
AvroWrapper.createStdData(_genericArray.get(idx), _elementSchema); } @Override - public void add(StdData e) { - _genericArray.add(((PlatformData) e).getUnderlyingData()); + public void add(E e) { + _genericArray.add(AvroWrapper.getPlatformData(e)); } @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { private final Iterator _iterator = _genericArray.iterator(); @Override @@ -55,8 +54,8 @@ public boolean hasNext() { } @Override - public StdData next() { - return AvroWrapper.createStdData(_iterator.next(), _elementSchema); + public E next() { + return (E) AvroWrapper.createStdData(_iterator.next(), _elementSchema); } }; } diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java deleted file mode 100644 index 902e610d..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdBinary; -import java.nio.ByteBuffer; - - -public class AvroBinary implements StdBinary, PlatformData { - private ByteBuffer _byteBuffer; - - public AvroBinary(ByteBuffer aByteBuffer) { - _byteBuffer = aByteBuffer; - } - - @Override - public Object getUnderlyingData() { - return _byteBuffer; - } - - @Override - public void setUnderlyingData(Object value) { - _byteBuffer = (ByteBuffer) value; - } - - @Override - public ByteBuffer get() { - return _byteBuffer; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBoolean.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBoolean.java deleted file mode 100644 index 99f83738..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBoolean.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdBoolean; - - -public class AvroBoolean implements StdBoolean, PlatformData { - private Boolean _boolean; - - public AvroBoolean(Boolean aBoolean) { - _boolean = aBoolean; - } - - @Override - public boolean get() { - return _boolean; - } - - @Override - public Object getUnderlyingData() { - return _boolean; - } - - @Override - public void setUnderlyingData(Object value) { - _boolean = (Boolean) value; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java deleted file mode 100644 index 214443ae..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdDouble; - - -public class AvroDouble implements StdDouble, PlatformData { - private Double _double; - - public AvroDouble(Double aDouble) { - _double = aDouble; - } - - @Override - public Object getUnderlyingData() { - return _double; - } - - @Override - public void setUnderlyingData(Object value) { - _double = (Double) value; - } - - @Override - public double get() { - return _double; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java deleted file mode 100644 index c4547d81..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdFloat; - - -public class AvroFloat implements StdFloat, PlatformData { - private Float _float; - - public AvroFloat(Float aFloat) { - _float = aFloat; - } - - @Override - public Object getUnderlyingData() { - return _float; - } - - @Override - public void setUnderlyingData(Object value) { - _float = (Float) value; - } - - @Override - public float get() { - return _float; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroInteger.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroInteger.java deleted file mode 100644 index 5a170f3b..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroInteger.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdInteger; - - -public class AvroInteger implements StdInteger, PlatformData { - private Integer _integer; - - public AvroInteger(Integer integer) { - _integer = integer; - } - - @Override - public int get() { - return _integer; - } - - @Override - public Object getUnderlyingData() { - return _integer; - } - - @Override - public void setUnderlyingData(Object value) { - _integer = (Integer) value; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroLong.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroLong.java deleted file mode 100644 index a56af06c..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroLong.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdLong; - - -public class AvroLong implements StdLong, PlatformData { - private Long _long; - - public AvroLong(Long aLong) { - _long = aLong; - } - - @Override - public long get() { - return _long; - } - - @Override - public Object getUnderlyingData() { - return _long; - } - - @Override - public void setUnderlyingData(Object value) { - _long = (Long) value; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMap.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMapData.java similarity index 59% rename from transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMap.java rename to transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMapData.java index d0913d53..2b95796c 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMap.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMapData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.avro.data; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.avro.AvroWrapper; import java.util.AbstractSet; import java.util.Collection; @@ -21,18 +20,18 @@ import static org.apache.avro.Schema.Type.*; -public class AvroMap implements StdMap, PlatformData { +public class AvroMapData implements MapData, PlatformData { private Map _map; private final Schema _keySchema; private final Schema _valueSchema; - public AvroMap(Map map, Schema mapSchema) { + public AvroMapData(Map map, Schema mapSchema) { _map = map; _keySchema = Schema.create(STRING); _valueSchema = mapSchema.getValueType(); } - public AvroMap(Schema mapSchema) { + public 
AvroMapData(Schema mapSchema) { _map = new LinkedHashMap<>(); _keySchema = Schema.create(STRING); _valueSchema = mapSchema.getValueType(); @@ -54,21 +53,21 @@ public int size() { } @Override - public StdData get(StdData key) { - return AvroWrapper.createStdData(_map.get(((PlatformData) key).getUnderlyingData()), _valueSchema); + public V get(K key) { + return (V) AvroWrapper.createStdData(_map.get(AvroWrapper.getPlatformData(key)), _valueSchema); } @Override - public void put(StdData key, StdData value) { - _map.put(((PlatformData) key).getUnderlyingData(), ((PlatformData) value).getUnderlyingData()); + public void put(K key, V value) { + _map.put(AvroWrapper.getPlatformData(key), AvroWrapper.getPlatformData(value)); } @Override - public Set keySet() { - return new AbstractSet() { + public Set keySet() { + return new AbstractSet() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Iterator keySet = _map.keySet().iterator(); @Override public boolean hasNext() { @@ -76,8 +75,8 @@ public boolean hasNext() { } @Override - public StdData next() { - return AvroWrapper.createStdData(keySet.next(), _keySchema); + public K next() { + return (K) AvroWrapper.createStdData(keySet.next(), _keySchema); } }; } @@ -90,12 +89,12 @@ public int size() { } @Override - public Collection values() { - return _map.values().stream().map(v -> AvroWrapper.createStdData(v, _valueSchema)).collect(Collectors.toList()); + public Collection values() { + return _map.values().stream().map(v -> (V) AvroWrapper.createStdData(v, _valueSchema)).collect(Collectors.toList()); } @Override - public boolean containsKey(StdData key) { - return _map.containsKey(((PlatformData) key).getUnderlyingData()); + public boolean containsKey(K key) { + return _map.containsKey(AvroWrapper.getPlatformData(key)); } } diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroStruct.java 
b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroRowData.java similarity index 68% rename from transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroStruct.java rename to transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroRowData.java index f018d5bc..64fa5e4c 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroStruct.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroRowData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.avro.data; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.avro.AvroWrapper; import java.util.List; import java.util.stream.Collectors; @@ -17,17 +16,17 @@ import org.apache.avro.generic.GenericRecord; -public class AvroStruct implements StdStruct, PlatformData { +public class AvroRowData implements RowData, PlatformData { private final Schema _recordSchema; private GenericRecord _genericRecord; - public AvroStruct(GenericRecord genericRecord, Schema recordSchema) { + public AvroRowData(GenericRecord genericRecord, Schema recordSchema) { _genericRecord = genericRecord; _recordSchema = recordSchema; } - public AvroStruct(Schema recordSchema) { + public AvroRowData(Schema recordSchema) { _genericRecord = new Record(recordSchema); _recordSchema = recordSchema; } @@ -43,27 +42,27 @@ public void setUnderlyingData(Object value) { } @Override - public StdData getField(int index) { + public Object getField(int index) { return AvroWrapper.createStdData(_genericRecord.get(index), _recordSchema.getFields().get(index).schema()); } @Override - public StdData getField(String name) { + public Object getField(String name) { return AvroWrapper.createStdData(_genericRecord.get(name), _recordSchema.getField(name).schema()); } @Override - 
public void setField(int index, StdData value) { - _genericRecord.put(index, ((PlatformData) value).getUnderlyingData()); + public void setField(int index, Object value) { + _genericRecord.put(index, AvroWrapper.getPlatformData(value)); } @Override - public void setField(String name, StdData value) { - _genericRecord.put(name, ((PlatformData) value).getUnderlyingData()); + public void setField(String name, Object value) { + _genericRecord.put(name, AvroWrapper.getPlatformData(value)); } @Override - public List fields() { + public List fields() { return IntStream.range(0, _recordSchema.getFields().size()).mapToObj(i -> getField(i)).collect(Collectors.toList()); } } diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroString.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroString.java deleted file mode 100644 index 745df05e..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroString.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdString; -import org.apache.avro.util.Utf8; - - -public class AvroString implements StdString, PlatformData { - private Utf8 _string; - - public AvroString(Utf8 string) { - _string = string; - } - - @Override - public String get() { - return _string.toString(); - } - - @Override - public Object getUnderlyingData() { - return _string; - } - - @Override - public void setUnderlyingData(Object value) { - _string = (Utf8) value; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroStructType.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroRowType.java similarity index 82% rename from transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroStructType.java rename to transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroRowType.java index 2c97b39c..3923b3f5 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroStructType.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroRowType.java @@ -5,7 +5,7 @@ */ package com.linkedin.transport.avro.types; -import com.linkedin.transport.api.types.StdStructType; +import com.linkedin.transport.api.types.RowType; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.avro.AvroWrapper; import java.util.List; @@ -13,10 +13,10 @@ import org.apache.avro.Schema; -public class AvroStructType implements StdStructType { +public class AvroRowType implements RowType { final private Schema _schema; - public AvroStructType(Schema schema) { + public AvroRowType(Schema schema) { _schema = schema; } diff --git a/transportable-udfs-avro/src/test/java/com/linkedin/transport/avro/TestAvroWrapper.java b/transportable-udfs-avro/src/test/java/com/linkedin/transport/avro/TestAvroWrapper.java index 
d81a452a..2257e977 100644 --- a/transportable-udfs-avro/src/test/java/com/linkedin/transport/avro/TestAvroWrapper.java +++ b/transportable-udfs-avro/src/test/java/com/linkedin/transport/avro/TestAvroWrapper.java @@ -7,37 +7,20 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.avro.data.AvroArray; -import com.linkedin.transport.avro.data.AvroBinary; -import com.linkedin.transport.avro.data.AvroBoolean; -import com.linkedin.transport.avro.data.AvroDouble; -import com.linkedin.transport.avro.data.AvroFloat; -import com.linkedin.transport.avro.data.AvroInteger; -import com.linkedin.transport.avro.data.AvroLong; -import com.linkedin.transport.avro.data.AvroMap; -import com.linkedin.transport.avro.data.AvroString; -import com.linkedin.transport.avro.data.AvroStruct; +import com.linkedin.transport.avro.data.AvroArrayData; +import com.linkedin.transport.avro.data.AvroMapData; +import com.linkedin.transport.avro.data.AvroRowData; import com.linkedin.transport.avro.types.AvroArrayType; -import com.linkedin.transport.avro.types.AvroBinaryType; -import com.linkedin.transport.avro.types.AvroBooleanType; -import com.linkedin.transport.avro.types.AvroDoubleType; -import com.linkedin.transport.avro.types.AvroFloatType; -import com.linkedin.transport.avro.types.AvroIntegerType; import com.linkedin.transport.avro.types.AvroLongType; import com.linkedin.transport.avro.types.AvroMapType; -import com.linkedin.transport.avro.types.AvroStringType; -import com.linkedin.transport.avro.types.AvroStructType; -import java.nio.ByteBuffer; +import com.linkedin.transport.avro.types.AvroRowType; import java.util.Arrays; import java.util.Map; import org.apache.avro.Schema; import org.apache.avro.generic.GenericArray; import org.apache.avro.generic.GenericData; import 
org.apache.avro.generic.GenericRecord; -import org.apache.avro.util.Utf8; import org.testng.annotations.Test; import static org.testng.Assert.*; @@ -54,61 +37,6 @@ private Schema createSchema(String fieldName, String typeName) { String.format("{\"name\": \"%s\",\"type\": %s}", fieldName, typeName)); } - private void testSimpleType(String typeName, Class expectedAvroTypeClass, - Object testData, Class expectedDataClass) { - Schema avroSchema = createSchema(String.format("\"%s\"", typeName)); - - StdType stdType = AvroWrapper.createStdType(avroSchema); - assertTrue(expectedAvroTypeClass.isAssignableFrom(stdType.getClass())); - assertEquals(avroSchema, stdType.underlyingType()); - - StdData stdData = AvroWrapper.createStdData(testData, avroSchema); - assertNotNull(stdData); - assertTrue(expectedDataClass.isAssignableFrom(stdData.getClass())); - if ("string".equals(typeName)) { - // Use String values for equality assertion as we support both Utf8 and String input types - assertEquals(testData.toString(), ((PlatformData) stdData).getUnderlyingData().toString()); - } else { - assertEquals(testData, ((PlatformData) stdData).getUnderlyingData()); - } - } - - @Test - public void testBooleanType() { - testSimpleType("boolean", AvroBooleanType.class, true, AvroBoolean.class); - } - - @Test - public void testIntegerType() { - testSimpleType("int", AvroIntegerType.class, 1, AvroInteger.class); - } - - @Test - public void testLongType() { - testSimpleType("long", AvroLongType.class, 1L, AvroLong.class); - } - - @Test - public void testFloatType() { - testSimpleType("float", AvroFloatType.class, 1.0f, AvroFloat.class); - } - - @Test - public void testDoubleType() { - testSimpleType("double", AvroDoubleType.class, 1.0, AvroDouble.class); - } - - @Test - public void testStringType() { - testSimpleType("string", AvroStringType.class, new Utf8("foo"), AvroString.class); - testSimpleType("string", AvroStringType.class, "foo", AvroString.class); - } - - @Test - public void 
testBinaryType() { - testSimpleType("bytes", AvroBinaryType.class, ByteBuffer.wrap("bar".getBytes()), AvroBinary.class); - } - @Test public void testEnumType() { Schema field1 = createSchema("field1", "" @@ -122,17 +50,17 @@ public void testEnumType() { GenericRecord record1 = new GenericData.Record(structSchema); record1.put("field1", "A"); - StdData stdEnumData1 = AvroWrapper.createStdData(record1.get("field1"), + Object stdEnumData1 = AvroWrapper.createStdData(record1.get("field1"), Schema.createEnum("SampleEnum", "", "", Arrays.asList("A", "B"))); - assertTrue(stdEnumData1 instanceof AvroString); - assertEquals("A", ((AvroString) stdEnumData1).get()); + assertTrue(stdEnumData1 instanceof String); + assertEquals("A", ((String) stdEnumData1)); GenericRecord record2 = new GenericData.Record(structSchema); record1.put("field1", new GenericData.EnumSymbol(field1, "A")); - StdData stdEnumData2 = AvroWrapper.createStdData(record1.get("field1"), + Object stdEnumData2 = AvroWrapper.createStdData(record1.get("field1"), Schema.createEnum("SampleEnum", "", "", Arrays.asList("A", "B"))); - assertTrue(stdEnumData2 instanceof AvroString); - assertEquals("A", ((AvroString) stdEnumData2).get()); + assertTrue(stdEnumData2 instanceof String); + assertEquals("A", ((String) stdEnumData2)); } @Test @@ -142,14 +70,14 @@ public void testArrayType() { StdType stdArrayType = AvroWrapper.createStdType(arraySchema); assertTrue(stdArrayType instanceof AvroArrayType); - assertEquals(arraySchema, stdArrayType.underlyingType()); + assertEquals(arraySchema, ((AvroArrayType) stdArrayType).underlyingType()); assertEquals(elementType, ((AvroArrayType) stdArrayType).elementType().underlyingType()); GenericArray value = new GenericData.Array<>(arraySchema, Arrays.asList(1, 2)); - StdData stdArrayData = AvroWrapper.createStdData(value, arraySchema); - assertTrue(stdArrayData instanceof AvroArray); - assertEquals(2, ((AvroArray) stdArrayData).size()); - assertEquals(value, ((AvroArray) 
stdArrayData).getUnderlyingData()); + Object stdArrayData = AvroWrapper.createStdData(value, arraySchema); + assertTrue(stdArrayData instanceof AvroArrayData); + assertEquals(2, ((AvroArrayData) stdArrayData).size()); + assertEquals(value, ((AvroArrayData) stdArrayData).getUnderlyingData()); } @Test @@ -163,10 +91,10 @@ public void testMapType() { assertEquals(valueType, ((AvroMapType) stdMapType).valueType().underlyingType()); Map value = ImmutableMap.of("foo", 1L, "bar", 2L); - StdData stdMapData = AvroWrapper.createStdData(value, mapSchema); - assertTrue(stdMapData instanceof AvroMap); - assertEquals(2, ((AvroMap) stdMapData).size()); - assertEquals(value, ((AvroMap) stdMapData).getUnderlyingData()); + Object stdMapData = AvroWrapper.createStdData(value, mapSchema); + assertTrue(stdMapData instanceof AvroMapData); + assertEquals(2, ((AvroMapData) stdMapData).size()); + assertEquals(value, ((AvroMapData) stdMapData).getUnderlyingData()); } @Test @@ -179,21 +107,21 @@ public void testRecordType() { )); StdType stdStructType = AvroWrapper.createStdType(structSchema); - assertTrue(stdStructType instanceof AvroStructType); + assertTrue(stdStructType instanceof AvroRowType); assertEquals(structSchema, stdStructType.underlyingType()); - assertEquals(field1, ((AvroStructType) stdStructType).fieldTypes().get(0).underlyingType()); - assertEquals(field2, ((AvroStructType) stdStructType).fieldTypes().get(1).underlyingType()); + assertEquals(field1, ((AvroRowType) stdStructType).fieldTypes().get(0).underlyingType()); + assertEquals(field2, ((AvroRowType) stdStructType).fieldTypes().get(1).underlyingType()); GenericRecord value = new GenericData.Record(structSchema); value.put("field1", 1); value.put("field2", 2.0); - StdData stdStructData = AvroWrapper.createStdData(value, structSchema); - assertTrue(stdStructData instanceof AvroStruct); - AvroStruct avroStruct = (AvroStruct) stdStructData; + Object stdStructData = AvroWrapper.createStdData(value, structSchema); + 
assertTrue(stdStructData instanceof AvroRowData); + AvroRowData avroStruct = (AvroRowData) stdStructData; assertEquals(2, avroStruct.fields().size()); assertEquals(value, avroStruct.getUnderlyingData()); - assertEquals(1, ((PlatformData) avroStruct.getField("field1")).getUnderlyingData()); - assertEquals(2.0, ((PlatformData) avroStruct.getField("field2")).getUnderlyingData()); + assertEquals(1, avroStruct.getField("field1")); + assertEquals(2.0, avroStruct.getField("field2")); } @Test @@ -205,11 +133,11 @@ public void testValidUnionType() { assertTrue(stdLongType instanceof AvroLongType); assertEquals(nonNullType, stdLongType.underlyingType()); - StdData stdLongData = AvroWrapper.createStdData(1L, unionSchema); - assertTrue(stdLongData instanceof AvroLong); - assertEquals(1L, ((AvroLong) stdLongData).get()); + Object stdLongData = AvroWrapper.createStdData(1L, unionSchema); + assertTrue(stdLongData instanceof Long); + assertEquals(1L, stdLongData); - StdData stdNullData = AvroWrapper.createStdData(null, unionSchema); + Object stdNullData = AvroWrapper.createStdData(null, unionSchema); assertNull(stdNullData); } @@ -242,21 +170,21 @@ public void testStructWithSimpleUnionField() { GenericRecord record1 = new GenericData.Record(structSchema); record1.put("field1", 1); record1.put("field2", 3.0); - AvroStruct avroStruct1 = (AvroStruct) AvroWrapper.createStdData(record1, structSchema); + AvroRowData avroStruct1 = (AvroRowData) AvroWrapper.createStdData(record1, structSchema); assertEquals(2, avroStruct1.fields().size()); - assertEquals(3.0, ((PlatformData) avroStruct1.getField("field2")).getUnderlyingData()); + assertEquals(3.0, avroStruct1.getField("field2")); GenericRecord record2 = new GenericData.Record(structSchema); record2.put("field1", 1); record2.put("field2", null); - AvroStruct avroStruct2 = (AvroStruct) AvroWrapper.createStdData(record2, structSchema); + AvroRowData avroStruct2 = (AvroRowData) AvroWrapper.createStdData(record2, structSchema); assertEquals(2, 
avroStruct2.fields().size()); assertNull(avroStruct2.getField("field2")); assertNull(avroStruct2.fields().get(1)); GenericRecord record3 = new GenericData.Record(structSchema); record3.put("field1", 1); - AvroStruct avroStruct3 = (AvroStruct) AvroWrapper.createStdData(record3, structSchema); + AvroRowData avroStruct3 = (AvroRowData) AvroWrapper.createStdData(record3, structSchema); assertEquals(2, avroStruct3.fields().size()); assertNull(avroStruct3.getField("field2")); assertNull(avroStruct3.fields().get(1)); diff --git a/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/SparkWrapperGenerator.java b/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/SparkWrapperGenerator.java index bec08d8e..1bc19ded 100644 --- a/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/SparkWrapperGenerator.java +++ b/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/SparkWrapperGenerator.java @@ -11,7 +11,9 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.util.Arrays; import java.util.Collection; +import java.util.Map; import java.util.stream.Collectors; import org.apache.commons.io.IOUtils; import org.apache.commons.text.StringSubstitutor; @@ -23,6 +25,7 @@ public class SparkWrapperGenerator implements WrapperGenerator { private static final String SPARK_WRAPPER_TEMPLATE_RESOURCE_PATH = "wrapper-templates/spark"; private static final String SUBSTITUTOR_KEY_WRAPPER_PACKAGE = "wrapperPackage"; private static final String SUBSTITUTOR_KEY_WRAPPER_CLASS = "wrapperClass"; + private static final String SUBSTITUTOR_KEY_WRAPPER_CLASS_PARAMERTERS = "wrapperClassParameters"; private static final String SUBSTITUTOR_KEY_UDF_TOP_LEVEL_CLASS = "udfTopLevelClass"; private static final String SUBSTITUTOR_KEY_UDF_IMPLEMENTATIONS = "udfImplementations"; @@ -30,12 +33,16 @@ public class SparkWrapperGenerator implements WrapperGenerator { public void 
generateWrappers(WrapperGeneratorContext context) { TransportUDFMetadata udfMetadata = context.getTransportUdfMetadata(); for (String topLevelClass : udfMetadata.getTopLevelClasses()) { - generateWrapper(topLevelClass, udfMetadata.getStdUDFImplementations(topLevelClass), + generateWrapper( + topLevelClass, + udfMetadata.getStdUDFImplementations(topLevelClass), + udfMetadata.getClassToNumberOfTypeParameters(), context.getSourcesOutputDir()); } } - private void generateWrapper(String topLevelClass, Collection implementationClasses, File outputDir) { + private void generateWrapper(String topLevelClass, Collection implementationClasses, + Map classToNumberOfTypeParameters, File outputDir) { final String wrapperTemplate; try (InputStream wrapperTemplateStream = Thread.currentThread() .getContextClassLoader() @@ -49,13 +56,15 @@ private void generateWrapper(String topLevelClass, Collection implementa ClassName wrapperClass = ClassName.get(topLevelClassName.packageName() + "." + SPARK_PACKAGE_SUFFIX, topLevelClassName.simpleName()); String udfImplementationInstantiations = implementationClasses.stream() - .map(clazz -> "new " + clazz + "()") + .map(clazz -> "new " + clazz + parameters(clazz, classToNumberOfTypeParameters) + "()") .collect(Collectors.joining(", ")); + String topLevelClassNameString = topLevelClassName.toString(); ImmutableMap substitutionMap = ImmutableMap.of( SUBSTITUTOR_KEY_WRAPPER_PACKAGE, wrapperClass.packageName(), SUBSTITUTOR_KEY_WRAPPER_CLASS, wrapperClass.simpleName(), - SUBSTITUTOR_KEY_UDF_TOP_LEVEL_CLASS, topLevelClassName.toString(), + SUBSTITUTOR_KEY_UDF_TOP_LEVEL_CLASS, topLevelClassNameString + + parameters(topLevelClassNameString, classToNumberOfTypeParameters), SUBSTITUTOR_KEY_UDF_IMPLEMENTATIONS, udfImplementationInstantiations ); @@ -69,4 +78,11 @@ private void generateWrapper(String topLevelClass, Collection implementa throw new RuntimeException("Error writing wrapper to file", e); } } + + private static String parameters(String clazz, 
Map classToNumberOfTypeParameters) { + int numberOfTypeParameters = classToNumberOfTypeParameters.get(clazz); + String[] objectTypes = new String[numberOfTypeParameters]; + Arrays.fill(objectTypes, "Object"); + return numberOfTypeParameters > 0 ? "[" + String.join(", ", objectTypes) + "]" : ""; + } } diff --git a/transportable-udfs-codegen/src/test/resources/inputs/sample-udf-metadata.json b/transportable-udfs-codegen/src/test/resources/inputs/sample-udf-metadata.json index 4d6fb8ae..6da584f2 100644 --- a/transportable-udfs-codegen/src/test/resources/inputs/sample-udf-metadata.json +++ b/transportable-udfs-codegen/src/test/resources/inputs/sample-udf-metadata.json @@ -1,17 +1,17 @@ { - "udfs": [ - { - "topLevelClass": "udfs.OverloadedUDF", - "stdUDFImplementations": [ - "udfs.OverloadedUDFInt", - "udfs.OverloadedUDFString" - ] - }, - { - "topLevelClass": "udfs.SimpleUDF", - "stdUDFImplementations": [ - "udfs.SimpleUDF" - ] - } - ] -} \ No newline at end of file + "udfs": { + "udfs.OverloadedUDF": [ + "udfs.OverloadedUDFInt", + "udfs.OverloadedUDFString" + ], + "udfs.SimpleUDF": [ + "udfs.SimpleUDF" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.OverloadedUDFString": 0, + "udfs.OverloadedUDF": 0, + "udfs.OverloadedUDFInt": 0, + "udfs.SimpleUDF": 0 + } +} diff --git a/transportable-udfs-compile-utils/src/main/java/com/linkedin/transport/compile/TransportUDFMetadata.java b/transportable-udfs-compile-utils/src/main/java/com/linkedin/transport/compile/TransportUDFMetadata.java index 48db80f5..c10cc44b 100644 --- a/transportable-udfs-compile-utils/src/main/java/com/linkedin/transport/compile/TransportUDFMetadata.java +++ b/transportable-udfs-compile-utils/src/main/java/com/linkedin/transport/compile/TransportUDFMetadata.java @@ -9,14 +9,18 @@ import com.google.common.collect.Multimap; import com.google.gson.Gson; import com.google.gson.GsonBuilder; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import 
com.google.gson.JsonParser; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.io.Reader; import java.io.Writer; import java.util.Collection; -import java.util.LinkedList; -import java.util.List; +import java.util.HashMap; +import java.util.Map; import java.util.Set; @@ -26,6 +30,7 @@ public class TransportUDFMetadata { private static final Gson GSON; private Multimap _udfs; + private Map _classToNumberOfTypeParameters; static { GSON = new GsonBuilder().setPrettyPrinting().create(); @@ -33,14 +38,15 @@ public class TransportUDFMetadata { public TransportUDFMetadata() { _udfs = LinkedHashMultimap.create(); + _classToNumberOfTypeParameters = new HashMap<>(); } public void addUDF(String topLevelClass, String stdUDFImplementation) { _udfs.put(topLevelClass, stdUDFImplementation); } - public void addUDF(String topLevelClass, Collection stdUDFImplementations) { - _udfs.putAll(topLevelClass, stdUDFImplementations); + public void setClassNumberOfTypeParameters(String clazz, int numberOfTypeParameters) { + _classToNumberOfTypeParameters.put(clazz, numberOfTypeParameters); } public Set getTopLevelClasses() { @@ -51,8 +57,12 @@ public Collection getStdUDFImplementations(String topLevelClass) { return _udfs.get(topLevelClass); } + public Map getClassToNumberOfTypeParameters() { + return _classToNumberOfTypeParameters; + } + public void toJson(Writer writer) { - GSON.toJson(TransportUDFMetadataSerDe.fromUDFMetadata(this), writer); + GSON.toJson(TransportUDFMetadataSerDe.serialize(this), writer); } public static TransportUDFMetadata fromJsonFile(File jsonFile) { @@ -64,50 +74,49 @@ public static TransportUDFMetadata fromJsonFile(File jsonFile) { } public static TransportUDFMetadata fromJson(Reader reader) { - return TransportUDFMetadataSerDe.toUDFMetadata(GSON.fromJson(reader, TransportUDFMetadataJson.class)); + return TransportUDFMetadataSerDe.deserialize(new JsonParser().parse(reader)); } - /** - * Represents the JSON object structure of 
the Transport UDF metadata resource file - */ - private static class TransportUDFMetadataJson { - private List udfs; + private static class TransportUDFMetadataSerDe { - TransportUDFMetadataJson() { - this.udfs = new LinkedList<>(); + public static TransportUDFMetadata deserialize(JsonElement json) { + TransportUDFMetadata metadata = new TransportUDFMetadata(); + JsonObject root = json.getAsJsonObject(); + + // Deserialize udfs + JsonObject udfs = root.getAsJsonObject("udfs"); + udfs.keySet().forEach(topLevelClass -> { + JsonArray stdUdfImplementations = udfs.getAsJsonArray(topLevelClass); + for (int i = 0; i < stdUdfImplementations.size(); i++) { + metadata.addUDF(topLevelClass, stdUdfImplementations.get(i).getAsString()); + } + }); + + // Deserialize classToNumberOfTypeParameters + JsonObject classToNumberOfTypeParameters = root.getAsJsonObject("classToNumberOfTypeParameters"); + classToNumberOfTypeParameters.entrySet().forEach( + e -> metadata.setClassNumberOfTypeParameters(e.getKey(), e.getValue().getAsInt()) + ); + return metadata; } - static class UDFInfo { - private String topLevelClass; - private Collection stdUDFImplementations; - - UDFInfo(String topLevelClass, Collection stdUDFImplementations) { - this.topLevelClass = topLevelClass; - this.stdUDFImplementations = stdUDFImplementations; + public static JsonElement serialize(TransportUDFMetadata metadata) { + // Serialzie _udfs + JsonObject udfs = new JsonObject(); + for (Map.Entry> entry : metadata._udfs.asMap().entrySet()) { + JsonArray stdUdfImplementations = new JsonArray(); + entry.getValue().forEach(f -> stdUdfImplementations.add(f)); + udfs.add(entry.getKey(), stdUdfImplementations); } - } - } - /** - * Converts objects between {@link TransportUDFMetadata} and {@link TransportUDFMetadataJson} - */ - private static class TransportUDFMetadataSerDe { - - private static TransportUDFMetadataJson fromUDFMetadata(TransportUDFMetadata metadata) { - TransportUDFMetadataJson metadataJson = new 
TransportUDFMetadataJson(); - for (String topLevelClass : metadata.getTopLevelClasses()) { - metadataJson.udfs.add( - new TransportUDFMetadataJson.UDFInfo(topLevelClass, metadata.getStdUDFImplementations(topLevelClass))); - } - return metadataJson; - } + // Serialize _classToNumberOfTypeParameters + JsonObject classToNumberOfTypeParameters = new JsonObject(); + metadata._classToNumberOfTypeParameters.forEach((clazz, n) -> classToNumberOfTypeParameters.addProperty(clazz, n)); - private static TransportUDFMetadata toUDFMetadata(TransportUDFMetadataJson metadataJson) { - TransportUDFMetadata metadata = new TransportUDFMetadata(); - for (TransportUDFMetadataJson.UDFInfo udf : metadataJson.udfs) { - metadata.addUDF(udf.topLevelClass, udf.stdUDFImplementations); - } - return metadata; + JsonObject root = new JsonObject(); + root.add("udfs", udfs); + root.add("classToNumberOfTypeParameters", classToNumberOfTypeParameters); + return root; } } } diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayElementAtFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayElementAtFunction.java index 2697f8db..e9ee2b22 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayElementAtFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayElementAtFunction.java @@ -6,15 +6,26 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdInteger; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class 
ArrayElementAtFunction extends StdUDF2 implements TopLevelStdUDF { +/** + * Another way to define this class using generics can look like this + * + * public class ArrayElementAtFunction extends StdUDF2, Integer, K> implements TopLevelStdUDF { + * + * @Override + * public K eval(ArrayData a1, Integer idx) { + * return a1.get(idx); + * } + * + * } + * + */ +public class ArrayElementAtFunction extends StdUDF2 implements TopLevelStdUDF { @Override public String getFunctionName() { @@ -40,7 +51,7 @@ public String getOutputParameterSignature() { } @Override - public StdData eval(StdArray a1, StdInteger idx) { - return a1.get(idx.get()); + public Object eval(ArrayData a1, Integer idx) { + return a1.get(idx); } } diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayFillFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayFillFunction.java index ae5a9ac1..a9cff404 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayFillFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayFillFunction.java @@ -7,16 +7,14 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdLong; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class ArrayFillFunction extends StdUDF2 implements TopLevelStdUDF { +public class ArrayFillFunction extends StdUDF2> implements TopLevelStdUDF { private StdType _arrayType; @@ -40,9 +38,9 @@ public void 
init(StdFactory stdFactory) { } @Override - public StdArray eval(StdData a, StdLong length) { - StdArray array = getStdFactory().createArray(_arrayType); - for (int i = 0; i < length.get(); i++) { + public ArrayData eval(K a, Long length) { + ArrayData array = getStdFactory().createArray(_arrayType); + for (int i = 0; i < length; i++) { array.add(a); } return array; diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java index 26a63111..8252b816 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java @@ -6,24 +6,22 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdBinary; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.nio.ByteBuffer; import java.util.List; -public class BinaryDuplicateFunction extends StdUDF1 implements TopLevelStdUDF { +public class BinaryDuplicateFunction extends StdUDF1 implements TopLevelStdUDF { @Override - public StdBinary eval(StdBinary binaryObject) { - ByteBuffer byteBuffer = binaryObject.get(); + public ByteBuffer eval(ByteBuffer byteBuffer) { ByteBuffer results = ByteBuffer.allocate(2 * byteBuffer.array().length); for (int i = 0; i < 2; i++) { for (byte b : byteBuffer.array()) { results.put(b); } } - return getStdFactory().createBinary(results); + return results; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java 
b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java index 0f4b538a..39b56cd4 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java @@ -6,17 +6,16 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdInteger; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; +import java.nio.ByteBuffer; import java.util.List; -public class BinaryObjectSizeFunction extends StdUDF1 implements TopLevelStdUDF { +public class BinaryObjectSizeFunction extends StdUDF1 implements TopLevelStdUDF { @Override - public StdInteger eval(StdBinary binaryObject) { - return getStdFactory().createInteger(binaryObject.get().array().length); + public Integer eval(ByteBuffer byteBuffer) { + return byteBuffer.array().length; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/FileLookupFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/FileLookupFunction.java index 8112e443..e9ed378f 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/FileLookupFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/FileLookupFunction.java @@ -7,9 +7,6 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdBoolean; -import 
com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.io.BufferedReader; @@ -21,14 +18,14 @@ import org.apache.commons.io.IOUtils; -public class FileLookupFunction extends StdUDF2 implements TopLevelStdUDF { +public class FileLookupFunction extends StdUDF2 implements TopLevelStdUDF { private Set ids; @Override - public StdBoolean eval(StdString filename, StdInteger intToCheck) { + public Boolean eval(String filename, Integer intToCheck) { Preconditions.checkNotNull(intToCheck, "Integer to check should not be null"); - return getStdFactory().createBoolean(ids.contains(intToCheck.get())); + return ids.contains(intToCheck); } @Override @@ -57,8 +54,8 @@ public String getFunctionDescription() { } @Override - public String[] getRequiredFiles(StdString filename, StdInteger intToCheck) { - return new String[]{filename.get()}; + public String[] getRequiredFiles(String filename, Integer intToCheck) { + return new String[]{filename}; } public void processRequiredFiles(String[] localPaths) { diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapFromTwoArraysFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapFromTwoArraysFunction.java index 6fd99981..d6002a50 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapFromTwoArraysFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapFromTwoArraysFunction.java @@ -7,15 +7,16 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdMap; +import 
com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class MapFromTwoArraysFunction extends StdUDF2 implements TopLevelStdUDF { +public class MapFromTwoArraysFunction extends StdUDF2, ArrayData, MapData> + implements TopLevelStdUDF { private StdType _mapType; @@ -35,16 +36,16 @@ public String getOutputParameterSignature() { @Override public void init(StdFactory stdFactory) { super.init(stdFactory); - // Note: we create the _mapType once in init() and then reuse it to create StdMap objects + // Note: we create the _mapType once in init() and then reuse it to create MapData objects _mapType = getStdFactory().createStdType(getOutputParameterSignature()); } @Override - public StdMap eval(StdArray a1, StdArray a2) { + public MapData eval(ArrayData a1, ArrayData a2) { if (a1.size() != a2.size()) { return null; } - StdMap map = getStdFactory().createMap(_mapType); + MapData map = getStdFactory().createMap(_mapType); for (int i = 0; i < a1.size(); i++) { map.put(a1.get(i), a2.get(i)); } diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapKeySetFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapKeySetFunction.java index a76c4403..af81e024 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapKeySetFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapKeySetFunction.java @@ -7,16 +7,15 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; 
-import com.linkedin.transport.api.data.StdMap; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class MapKeySetFunction extends StdUDF1 implements TopLevelStdUDF { +public class MapKeySetFunction extends StdUDF1, ArrayData> implements TopLevelStdUDF { private StdType _mapType; @@ -39,9 +38,9 @@ public void init(StdFactory stdFactory) { } @Override - public StdArray eval(StdMap map) { - StdArray result = getStdFactory().createArray(_mapType); - for (StdData key : map.keySet()) { + public ArrayData eval(MapData map) { + ArrayData result = getStdFactory().createArray(_mapType); + for (K key : map.keySet()) { result.add(key); } return result; diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapValuesFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapValuesFunction.java index f22ff7f7..82b34ef1 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapValuesFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapValuesFunction.java @@ -7,16 +7,15 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public 
class MapValuesFunction extends StdUDF1 implements TopLevelStdUDF { +public class MapValuesFunction extends StdUDF1, ArrayData> implements TopLevelStdUDF { private StdType _mapType; @@ -39,9 +38,9 @@ public void init(StdFactory stdFactory) { } @Override - public StdArray eval(StdMap map) { - StdArray result = getStdFactory().createArray(_mapType); - for (StdData value : map.values()) { + public ArrayData eval(MapData map) { + ArrayData result = getStdFactory().createArray(_mapType); + for (V value : map.values()) { result.add(value); } return result; diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NestedMapFromTwoArraysFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NestedMapFromTwoArraysFunction.java index 986bcda0..5e244a11 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NestedMapFromTwoArraysFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NestedMapFromTwoArraysFunction.java @@ -7,16 +7,16 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class NestedMapFromTwoArraysFunction extends StdUDF1 implements TopLevelStdUDF { +public class NestedMapFromTwoArraysFunction extends StdUDF1 implements TopLevelStdUDF { private StdType _arrayType; private StdType _mapType; 
@@ -43,31 +43,31 @@ public void init(StdFactory stdFactory) { } @Override - public StdArray eval(StdArray a1) { - StdArray result = getStdFactory().createArray(_arrayType); + public ArrayData eval(ArrayData a1) { + ArrayData result = getStdFactory().createArray(_arrayType); for (int i = 0; i < a1.size(); i++) { if (a1.get(i) == null) { return null; } - StdStruct inputRow = (StdStruct) a1.get(i); + RowData inputRow = (RowData) a1.get(i); if (inputRow.getField(0) == null || inputRow.getField(1) == null) { return null; } - StdArray kValues = (StdArray) inputRow.getField(0); - StdArray vValues = (StdArray) inputRow.getField(1); + ArrayData kValues = (ArrayData) inputRow.getField(0); + ArrayData vValues = (ArrayData) inputRow.getField(1); if (kValues.size() != vValues.size()) { return null; } - StdMap map = getStdFactory().createMap(_mapType); + MapData map = getStdFactory().createMap(_mapType); for (int j = 0; j < kValues.size(); j++) { map.put(kValues.get(j), vValues.get(j)); } - StdStruct outputRow = getStdFactory().createStruct(_rowType); + RowData outputRow = getStdFactory().createStruct(_rowType); outputRow.setField(0, map); result.add(outputRow); diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java index 6ee9c918..80b0fb2b 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java @@ -6,15 +6,14 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdDouble; import com.linkedin.transport.api.udf.StdUDF2; import java.util.List; 
-public class NumericAddDoubleFunction extends StdUDF2 implements NumericAddFunction { +public class NumericAddDoubleFunction extends StdUDF2 implements NumericAddFunction { @Override - public StdDouble eval(StdDouble first, StdDouble second) { - return getStdFactory().createDouble(first.get() + second.get()); + public Double eval(Double first, Double second) { + return first + second; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java index 643b558b..a2a0ab47 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java @@ -6,15 +6,14 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdFloat; import com.linkedin.transport.api.udf.StdUDF2; import java.util.List; -public class NumericAddFloatFunction extends StdUDF2 implements NumericAddFunction { +public class NumericAddFloatFunction extends StdUDF2 implements NumericAddFunction { @Override - public StdFloat eval(StdFloat first, StdFloat second) { - return getStdFactory().createFloat(first.get() + second.get()); + public Float eval(Float first, Float second) { + return first + second; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddIntFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddIntFunction.java index cc5fb900..bcdb3696 100644 --- 
a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddIntFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddIntFunction.java @@ -6,16 +6,15 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdInteger; import com.linkedin.transport.api.udf.StdUDF2; import java.util.List; -public class NumericAddIntFunction extends StdUDF2 +public class NumericAddIntFunction extends StdUDF2 implements NumericAddFunction { @Override - public StdInteger eval(StdInteger first, StdInteger second) { - return getStdFactory().createInteger(first.get() + second.get()); + public Integer eval(Integer first, Integer second) { + return first + second; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddLongFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddLongFunction.java index a530e586..c24d2148 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddLongFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddLongFunction.java @@ -6,15 +6,14 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdLong; import com.linkedin.transport.api.udf.StdUDF2; import java.util.List; -public class NumericAddLongFunction extends StdUDF2 implements NumericAddFunction { +public class NumericAddLongFunction extends StdUDF2 implements NumericAddFunction { @Override - public StdLong eval(StdLong first, StdLong second) { - return getStdFactory().createLong(first.get() + second.get()); + 
public Long eval(Long first, Long second) { + return first + second; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByIndexFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByIndexFunction.java index 5a23283d..ffa78ba7 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByIndexFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByIndexFunction.java @@ -7,15 +7,14 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class StructCreateByIndexFunction extends StdUDF2 implements TopLevelStdUDF { +public class StructCreateByIndexFunction extends StdUDF2 implements TopLevelStdUDF { private StdType _field1Type; private StdType _field2Type; @@ -41,11 +40,11 @@ public void init(StdFactory stdFactory) { } @Override - public StdStruct eval(StdData field1Value, StdData field2Value) { - StdStruct struct = getStdFactory().createStruct(ImmutableList.of(_field1Type, _field2Type)); - struct.setField(0, field1Value); - struct.setField(1, field2Value); - return struct; + public RowData eval(Object field1Value, Object field2Value) { + RowData row = getStdFactory().createStruct(ImmutableList.of(_field1Type, _field2Type)); + row.setField(0, field1Value); + row.setField(1, field2Value); + return row; } @Override diff --git 
a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByNameFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByNameFunction.java index 36ca3472..b4f2a0c0 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByNameFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByNameFunction.java @@ -7,16 +7,14 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF4; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class StructCreateByNameFunction extends StdUDF4 implements TopLevelStdUDF { +public class StructCreateByNameFunction extends StdUDF4 implements TopLevelStdUDF { private StdType _field1Type; private StdType _field2Type; @@ -44,13 +42,13 @@ public void init(StdFactory stdFactory) { } @Override - public StdStruct eval(StdString field1Name, StdData field1Value, StdString field2Name, StdData field2Value) { - StdStruct struct = getStdFactory().createStruct( - ImmutableList.of(field1Name.get(), field2Name.get()), + public RowData eval(String field1Name, Object field1Value, String field2Name, Object field2Value) { + RowData struct = getStdFactory().createStruct( + ImmutableList.of(field1Name, field2Name), ImmutableList.of(_field1Type, _field2Type) ); - struct.setField(field1Name.get(), field1Value); - struct.setField(field2Name.get(), field2Value); + struct.setField(field1Name, field1Value); + 
struct.setField(field2Name, field2Value); return struct; } diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java index e0373b63..c058e56b 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java @@ -5,34 +5,18 @@ */ package com.linkedin.transport.hive; -import com.google.common.base.Preconditions; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdBoolean; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdDouble; -import com.linkedin.transport.api.data.StdFloat; -import com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdLong; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.hive.data.HiveArray; -import com.linkedin.transport.hive.data.HiveBoolean; -import com.linkedin.transport.hive.data.HiveBinary; -import com.linkedin.transport.hive.data.HiveDouble; -import com.linkedin.transport.hive.data.HiveFloat; -import com.linkedin.transport.hive.data.HiveInteger; -import com.linkedin.transport.hive.data.HiveLong; -import com.linkedin.transport.hive.data.HiveMap; -import com.linkedin.transport.hive.data.HiveString; -import com.linkedin.transport.hive.data.HiveStruct; +import com.linkedin.transport.hive.data.HiveArrayData; +import com.linkedin.transport.hive.data.HiveMapData; +import com.linkedin.transport.hive.data.HiveRowData; import 
com.linkedin.transport.hive.types.objectinspector.CacheableObjectInspectorConverters; import com.linkedin.transport.hive.typesystem.HiveTypeFactory; import com.linkedin.transport.typesystem.AbstractBoundVariables; import com.linkedin.transport.typesystem.TypeSignature; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -45,7 +29,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; public class HiveFactory implements StdFactory { @@ -61,59 +44,23 @@ public HiveFactory(AbstractBoundVariables boundVariables) { } @Override - public StdInteger createInteger(int value) { - return new HiveInteger(value, PrimitiveObjectInspectorFactory.javaIntObjectInspector, this); - } - - @Override - public StdLong createLong(long value) { - return new HiveLong(value, PrimitiveObjectInspectorFactory.javaLongObjectInspector, this); - } - - @Override - public StdBoolean createBoolean(boolean value) { - return new HiveBoolean(value, PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, this); - } - - @Override - public StdString createString(String value) { - Preconditions.checkNotNull(value, "Cannot create a null StdString"); - return new HiveString(value, PrimitiveObjectInspectorFactory.javaStringObjectInspector, this); - } - - @Override - public StdFloat createFloat(float value) { - return new HiveFloat(value, PrimitiveObjectInspectorFactory.javaFloatObjectInspector, this); - } - - @Override - public StdDouble createDouble(double value) { - return new HiveDouble(value, PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, this); - } - - @Override - public StdBinary createBinary(ByteBuffer value) { - return new 
HiveBinary(value.array(), PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector, this); - } - - @Override - public StdArray createArray(StdType stdType, int expectedSize) { + public ArrayData createArray(StdType stdType, int expectedSize) { ListObjectInspector listObjectInspector = (ListObjectInspector) stdType.underlyingType(); - return new HiveArray( + return new HiveArrayData( new ArrayList(expectedSize), ObjectInspectorFactory.getStandardListObjectInspector(listObjectInspector.getListElementObjectInspector()), this); } @Override - public StdArray createArray(StdType stdType) { + public ArrayData createArray(StdType stdType) { return createArray(stdType, 0); } @Override - public StdMap createMap(StdType stdType) { + public MapData createMap(StdType stdType) { MapObjectInspector mapObjectInspector = (MapObjectInspector) stdType.underlyingType(); - return new HiveMap( + return new HiveMapData( new HashMap(), ObjectInspectorFactory.getStandardMapObjectInspector( mapObjectInspector.getMapKeyObjectInspector(), @@ -122,8 +69,8 @@ public StdMap createMap(StdType stdType) { } @Override - public StdStruct createStruct(List fieldNames, List fieldTypes) { - return new HiveStruct( + public RowData createStruct(List fieldNames, List fieldTypes) { + return new HiveRowData( new ArrayList(Arrays.asList(new Object[fieldTypes.size()])), ObjectInspectorFactory.getStandardStructObjectInspector( fieldNames, @@ -133,16 +80,16 @@ public StdStruct createStruct(List fieldNames, List fieldTypes) } @Override - public StdStruct createStruct(List fieldTypes) { + public RowData createStruct(List fieldTypes) { List fieldNames = IntStream.range(0, fieldTypes.size()).mapToObj(i -> "field" + i).collect(Collectors.toList()); return createStruct(fieldNames, fieldTypes); } @Override - public StdStruct createStruct(StdType stdType) { + public RowData createStruct(StdType stdType) { StructObjectInspector structObjectInspector = (StructObjectInspector) stdType.underlyingType(); - return new 
HiveStruct( + return new HiveRowData( new ArrayList(Arrays.asList(new Object[structObjectInspector.getAllStructFieldRefs().size()])), ObjectInspectorFactory.getStandardStructObjectInspector( structObjectInspector.getAllStructFieldRefs() diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java index 3b9daa43..2b06d7a4 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java @@ -6,18 +6,11 @@ package com.linkedin.transport.hive; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.hive.data.HiveArray; -import com.linkedin.transport.hive.data.HiveBoolean; -import com.linkedin.transport.hive.data.HiveBinary; -import com.linkedin.transport.hive.data.HiveDouble; -import com.linkedin.transport.hive.data.HiveFloat; -import com.linkedin.transport.hive.data.HiveInteger; -import com.linkedin.transport.hive.data.HiveLong; -import com.linkedin.transport.hive.data.HiveMap; -import com.linkedin.transport.hive.data.HiveString; -import com.linkedin.transport.hive.data.HiveStruct; +import com.linkedin.transport.hive.data.HiveArrayData; +import com.linkedin.transport.hive.data.HiveData; +import com.linkedin.transport.hive.data.HiveMapData; +import com.linkedin.transport.hive.data.HiveRowData; import com.linkedin.transport.hive.types.HiveArrayType; import com.linkedin.transport.hive.types.HiveBooleanType; import com.linkedin.transport.hive.types.HiveBinaryType; @@ -27,11 +20,13 @@ import com.linkedin.transport.hive.types.HiveLongType; import com.linkedin.transport.hive.types.HiveMapType; import com.linkedin.transport.hive.types.HiveStringType; -import com.linkedin.transport.hive.types.HiveStructType; +import 
com.linkedin.transport.hive.types.HiveRowType; import com.linkedin.transport.hive.types.HiveUnknownType; +import java.nio.ByteBuffer; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; @@ -39,6 +34,14 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableBinaryObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableBooleanObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableDoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableFloatObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableIntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableLongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableStringObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.VoidObjectInspector; @@ -48,28 +51,23 @@ public final class HiveWrapper { private HiveWrapper() { } - public static StdData 
createStdData(Object hiveData, ObjectInspector hiveObjectInspector, StdFactory stdFactory) { - if (hiveObjectInspector instanceof IntObjectInspector) { - return new HiveInteger(hiveData, (IntObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof LongObjectInspector) { - return new HiveLong(hiveData, (LongObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof BooleanObjectInspector) { - return new HiveBoolean(hiveData, (BooleanObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof StringObjectInspector) { - return new HiveString(hiveData, (StringObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof FloatObjectInspector) { - return new HiveFloat(hiveData, (FloatObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof DoubleObjectInspector) { - return new HiveDouble(hiveData, (DoubleObjectInspector) hiveObjectInspector, stdFactory); + public static Object createStdData(Object hiveData, ObjectInspector hiveObjectInspector, StdFactory stdFactory) { + if (hiveObjectInspector instanceof IntObjectInspector || hiveObjectInspector instanceof LongObjectInspector + || hiveObjectInspector instanceof FloatObjectInspector || hiveObjectInspector instanceof DoubleObjectInspector + || hiveObjectInspector instanceof BooleanObjectInspector + || hiveObjectInspector instanceof StringObjectInspector) { + return ((PrimitiveObjectInspector) hiveObjectInspector).getPrimitiveJavaObject(hiveData); } else if (hiveObjectInspector instanceof BinaryObjectInspector) { - return new HiveBinary(hiveData, (BinaryObjectInspector) hiveObjectInspector, stdFactory); + BinaryObjectInspector binaryObjectInspector = (BinaryObjectInspector) hiveObjectInspector; + return hiveData == null ? 
null : ByteBuffer.wrap(binaryObjectInspector.getPrimitiveJavaObject(hiveData)); } else if (hiveObjectInspector instanceof ListObjectInspector) { ListObjectInspector listObjectInspector = (ListObjectInspector) hiveObjectInspector; - return new HiveArray(hiveData, listObjectInspector, stdFactory); + return new HiveArrayData(hiveData, listObjectInspector, stdFactory); } else if (hiveObjectInspector instanceof MapObjectInspector) { - return new HiveMap(hiveData, hiveObjectInspector, stdFactory); + return new HiveMapData(hiveData, hiveObjectInspector, stdFactory); } else if (hiveObjectInspector instanceof StructObjectInspector) { - return new HiveStruct(hiveData, hiveObjectInspector, stdFactory); + return new HiveRowData(((StructObjectInspector) hiveObjectInspector).getStructFieldsDataAsList(hiveData).toArray(), + hiveObjectInspector, stdFactory); } else if (hiveObjectInspector instanceof VoidObjectInspector) { return null; } @@ -97,11 +95,57 @@ public static StdType createStdType(ObjectInspector hiveObjectInspector) { } else if (hiveObjectInspector instanceof MapObjectInspector) { return new HiveMapType((MapObjectInspector) hiveObjectInspector); } else if (hiveObjectInspector instanceof StructObjectInspector) { - return new HiveStructType((StructObjectInspector) hiveObjectInspector); + return new HiveRowType((StructObjectInspector) hiveObjectInspector); } else if (hiveObjectInspector instanceof VoidObjectInspector) { return new HiveUnknownType((VoidObjectInspector) hiveObjectInspector); } assert false : "Unrecognized Hive ObjectInspector: " + hiveObjectInspector.getClass(); return null; } + + public static Object getPlatformDataForObjectInspector(Object transportData, ObjectInspector oi) { + if (transportData == null) { + return null; + } else if (oi instanceof IntObjectInspector) { + return ((SettableIntObjectInspector) oi).create((Integer) transportData); + } else if (oi instanceof LongObjectInspector) { + return ((SettableLongObjectInspector) oi).create((Long) 
transportData); + } else if (oi instanceof FloatObjectInspector) { + return ((SettableFloatObjectInspector) oi).create((Float) transportData); + } else if (oi instanceof DoubleObjectInspector) { + return ((SettableDoubleObjectInspector) oi).create((Double) transportData); + } else if (oi instanceof BooleanObjectInspector) { + return ((SettableBooleanObjectInspector) oi).create((Boolean) transportData); + } else if (oi instanceof StringObjectInspector) { + return ((SettableStringObjectInspector) oi).create((String) transportData); + } else if (oi instanceof BinaryObjectInspector) { + return ((SettableBinaryObjectInspector) oi).create(((ByteBuffer) transportData).array()); + } else { + return ((HiveData) transportData).getUnderlyingDataForObjectInspector(oi); + } + } + + public static Object getStandardObject(Object transportData) { + if (transportData == null) { + return null; + } else if (transportData instanceof Integer) { + return PrimitiveObjectInspectorFactory.writableIntObjectInspector.create((Integer) transportData); + } else if (transportData instanceof Long) { + return PrimitiveObjectInspectorFactory.writableLongObjectInspector.create((Long) transportData); + } else if (transportData instanceof Float) { + return PrimitiveObjectInspectorFactory.writableFloatObjectInspector.create((Float) transportData); + } else if (transportData instanceof Double) { + return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.create((Double) transportData); + } else if (transportData instanceof Boolean) { + return PrimitiveObjectInspectorFactory.writableBooleanObjectInspector.create((Boolean) transportData); + } else if (transportData instanceof String) { + return PrimitiveObjectInspectorFactory.writableStringObjectInspector.create((String) transportData); + } else if (transportData instanceof ByteBuffer) { + return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.create(((ByteBuffer) transportData).array()); + } else { + return ((HiveData) 
transportData).getUnderlyingDataForObjectInspector( + ((HiveData) transportData).getUnderlyingObjectInspector() + ); + } + } } diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/StdUdfWrapper.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/StdUdfWrapper.java index bfa5cb6d..b932bf77 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/StdUdfWrapper.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/StdUdfWrapper.java @@ -7,7 +7,6 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.StdUDF1; @@ -23,6 +22,7 @@ import com.linkedin.transport.utils.FileSystemUtils; import java.io.FileNotFoundException; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.stream.IntStream; @@ -35,7 +35,8 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; - +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; /** * Base class for all Hive Standard UDFs. 
It provides a standard way of type validation, binding, and output type @@ -49,7 +50,8 @@ public abstract class StdUdfWrapper extends GenericUDF { protected StdFactory _stdFactory; private boolean[] _nullableArguments; private String[] _distributedCacheFiles; - private StdData[] _args; + private Object[] _args; + private ObjectInspector _outputObjectInspector; /** * Given input object inspectors, this method matches them to the expected type signatures, and finds bindings to the @@ -70,7 +72,8 @@ public ObjectInspector initialize(ObjectInspector[] arguments) { _stdUdf.init(_stdFactory); _requiredFilesProcessed = false; createStdData(); - return hiveTypeInference.getOutputDataType(); + _outputObjectInspector = hiveTypeInference.getOutputDataType(); + return _outputObjectInspector; } @Override @@ -108,14 +111,23 @@ protected boolean containsNullValuedNonNullableConstants() { return false; } - protected StdData wrap(DeferredObject hiveDeferredObject, StdData stdData) { + protected Object wrap(DeferredObject hiveDeferredObject, ObjectInspector inputObjectInspector, Object stdData) { try { Object hiveObject = hiveDeferredObject.get(); - if (hiveObject != null) { - ((PlatformData) stdData).setUnderlyingData(hiveObject); - return stdData; + if (inputObjectInspector instanceof BinaryObjectInspector) { + return hiveObject == null ? 
null : ByteBuffer.wrap( + ((BinaryObjectInspector) inputObjectInspector).getPrimitiveJavaObject(hiveObject) + ); + } + if (inputObjectInspector instanceof PrimitiveObjectInspector) { + return ((PrimitiveObjectInspector) inputObjectInspector).getPrimitiveJavaObject(hiveObject); } else { - return null; + if (hiveObject != null) { + ((PlatformData) stdData).setUnderlyingData(hiveObject); + return stdData; + } else { + return null; + } } } catch (HiveException e) { throw new RuntimeException("Cannot extract Hive Object from Deferred Object"); @@ -127,21 +139,35 @@ protected StdData wrap(DeferredObject hiveDeferredObject, StdData stdData) { protected abstract Class getTopLevelUdfClass(); protected void createStdData() { - _args = new StdData[_inputObjectInspectors.length]; + _args = new Object[_inputObjectInspectors.length]; for (int i = 0; i < _inputObjectInspectors.length; i++) { _args[i] = HiveWrapper.createStdData(null, _inputObjectInspectors[i], _stdFactory); } } - private StdData[] wrapArguments(DeferredObject[] deferredObjects) { - return IntStream.range(0, _args.length).mapToObj(i -> wrap(deferredObjects[i], _args[i])).toArray(StdData[]::new); + private Object getPlatformData(Object transportData) { + if (transportData == null) { + return null; + } else if (transportData instanceof Integer || transportData instanceof Long || transportData instanceof Boolean + || transportData instanceof String || transportData instanceof Float || transportData instanceof Double + || transportData instanceof ByteBuffer) { + return HiveWrapper.getPlatformDataForObjectInspector(transportData, _outputObjectInspector); + } else { + return ((PlatformData) transportData).getUnderlyingData(); + } + } + + private Object[] wrapArguments(DeferredObject[] deferredObjects) { + return IntStream.range(0, _args.length).mapToObj( + i -> wrap(deferredObjects[i], _inputObjectInspectors[i], _args[i]) + ).toArray(Object[]::new); } - private StdData[] wrapConstants() { + private Object[] 
wrapConstants() { return Arrays.stream(_inputObjectInspectors) .map(oi -> (oi instanceof ConstantObjectInspector) ? HiveWrapper.createStdData( ((ConstantObjectInspector) oi).getWritableConstantValue(), oi, _stdFactory) : null) - .toArray(StdData[]::new); + .toArray(Object[]::new); } @Override @@ -152,8 +178,8 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { if (!_requiredFilesProcessed) { processRequiredFiles(); } - StdData[] args = wrapArguments(arguments); - StdData result; + Object[] args = wrapArguments(arguments); + Object result; switch (args.length) { case 0: result = ((StdUDF0) _stdUdf).eval(); @@ -185,7 +211,7 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { default: throw new UnsupportedOperationException("eval not yet supported for StdUDF" + args.length); } - return result == null ? null : ((PlatformData) result).getUnderlyingData(); + return getPlatformData(result); } @Override @@ -193,7 +219,7 @@ public String[] getRequiredFiles() { if (containsNullValuedNonNullableConstants()) { return new String[]{}; } - StdData[] args = wrapConstants(); + Object[] args = wrapConstants(); String[] requiredFiles; switch (args.length) { case 0: diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArray.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArrayData.java similarity index 72% rename from transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArray.java rename to transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArrayData.java index 57cb0e8c..d0bf8ec4 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArray.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArrayData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.hive.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; 
-import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.hive.HiveWrapper; import java.util.Iterator; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; @@ -15,12 +14,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.SettableListObjectInspector; -public class HiveArray extends HiveData implements StdArray { +public class HiveArrayData extends HiveData implements ArrayData { final ListObjectInspector _listObjectInspector; final ObjectInspector _elementObjectInspector; - public HiveArray(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { + public HiveArrayData(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { super(stdFactory); _object = object; _listObjectInspector = (ListObjectInspector) objectInspector; @@ -33,19 +32,21 @@ public int size() { } @Override - public StdData get(int idx) { - return HiveWrapper.createStdData(_listObjectInspector.getListElement(_object, idx), _elementObjectInspector, + public E get(int idx) { + return (E) HiveWrapper.createStdData( + _listObjectInspector.getListElement(_object, idx), + _elementObjectInspector, _stdFactory); } @Override - public void add(StdData e) { + public void add(E e) { if (_listObjectInspector instanceof SettableListObjectInspector) { SettableListObjectInspector settableListObjectInspector = (SettableListObjectInspector) _listObjectInspector; int originalSize = size(); settableListObjectInspector.resize(_object, originalSize + 1); settableListObjectInspector.set(_object, originalSize, - ((HiveData) e).getUnderlyingDataForObjectInspector(_elementObjectInspector)); + HiveWrapper.getPlatformDataForObjectInspector(e, _elementObjectInspector)); _isObjectModified = true; } else { throw new RuntimeException("Attempt to modify an immutable Hive object of type: " @@ -59,8 +60,8 @@ public ObjectInspector getUnderlyingObjectInspector() { } @Override - public Iterator iterator() 
{ - return new Iterator() { + public Iterator iterator() { + return new Iterator() { int size = size(); int currentIndex = 0; @@ -70,8 +71,8 @@ public boolean hasNext() { } @Override - public StdData next() { - StdData element = HiveWrapper.createStdData(_listObjectInspector.getListElement(_object, currentIndex), + public E next() { + E element = (E) HiveWrapper.createStdData(_listObjectInspector.getListElement(_object, currentIndex), _elementObjectInspector, _stdFactory); currentIndex++; return element; diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java deleted file mode 100644 index c5c14e40..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdBinary; -import java.nio.ByteBuffer; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; - - -public class HiveBinary extends HiveData implements StdBinary { - - private final BinaryObjectInspector _binaryObjectInspector; - - public HiveBinary(Object object, BinaryObjectInspector binaryObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _binaryObjectInspector = binaryObjectInspector; - } - - @Override - public ByteBuffer get() { - return ByteBuffer.wrap(_binaryObjectInspector.getPrimitiveJavaObject(_object)); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _binaryObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBoolean.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBoolean.java deleted file mode 100644 index b4537170..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBoolean.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdBoolean; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; - - -public class HiveBoolean extends HiveData implements StdBoolean { - - final BooleanObjectInspector _booleanObjectInspector; - - public HiveBoolean(Object object, BooleanObjectInspector booleanObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _booleanObjectInspector = booleanObjectInspector; - } - - @Override - public boolean get() { - return _booleanObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _booleanObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveData.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveData.java index 51beb456..94337266 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveData.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveData.java @@ -59,10 +59,6 @@ public ObjectInspector getStandardObjectInspector() { getUnderlyingObjectInspector(), ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE); } - public Object getStandardObject() { - return getUnderlyingDataForObjectInspector(getStandardObjectInspector()); - } - private Object getObjectFromCache(ObjectInspector oi) { if (_isObjectModified) { _cachedObjectsForObjectInspectors.clear(); diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java deleted file mode 100644 index e5447f00..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java +++ 
/dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdDouble; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; - - -public class HiveDouble extends HiveData implements StdDouble { - - private final DoubleObjectInspector _doubleObjectInspector; - - public HiveDouble(Object object, DoubleObjectInspector floatObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _doubleObjectInspector = floatObjectInspector; - } - - @Override - public double get() { - return _doubleObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _doubleObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java deleted file mode 100644 index a630d73b..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdFloat; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector; - - -public class HiveFloat extends HiveData implements StdFloat { - - private final FloatObjectInspector _floatObjectInspector; - - public HiveFloat(Object object, FloatObjectInspector floatObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _floatObjectInspector = floatObjectInspector; - } - - @Override - public float get() { - return _floatObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _floatObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveInteger.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveInteger.java deleted file mode 100644 index a1d2e38f..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveInteger.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdInteger; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; - - -public class HiveInteger extends HiveData implements StdInteger { - - final IntObjectInspector _intObjectInspector; - - public HiveInteger(Object object, IntObjectInspector intObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _intObjectInspector = intObjectInspector; - } - - @Override - public int get() { - return _intObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _intObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveLong.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveLong.java deleted file mode 100644 index 0b662b59..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveLong.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdLong; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; - - -public class HiveLong extends HiveData implements StdLong { - - final LongObjectInspector _longObjectInspector; - - public HiveLong(Object object, LongObjectInspector longObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _longObjectInspector = longObjectInspector; - } - - @Override - public long get() { - return _longObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _longObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMap.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMapData.java similarity index 66% rename from transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMap.java rename to transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMapData.java index 70f5132b..54da6042 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMap.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMapData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.hive.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.hive.HiveWrapper; import java.util.AbstractCollection; import java.util.AbstractSet; @@ -20,13 +19,13 @@ import org.apache.hadoop.hive.serde2.objectinspector.SettableMapObjectInspector; -public class HiveMap extends HiveData implements StdMap { +public class HiveMapData extends HiveData 
implements MapData { final MapObjectInspector _mapObjectInspector; final ObjectInspector _keyObjectInspector; final ObjectInspector _valueObjectInspector; - public HiveMap(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { + public HiveMapData(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { super(stdFactory); _object = object; _mapObjectInspector = (MapObjectInspector) objectInspector; @@ -40,30 +39,30 @@ public int size() { } @Override - public StdData get(StdData key) { + public V get(K key) { MapObjectInspector mapOI = _mapObjectInspector; Object mapObj = _object; Object keyObj; try { - keyObj = ((HiveData) key).getUnderlyingDataForObjectInspector(_keyObjectInspector); + keyObj = HiveWrapper.getPlatformDataForObjectInspector(key, _keyObjectInspector); } catch (RuntimeException e) { // Cannot convert key argument to Map's KeyOI. So convert both the map and the key arg to // objects having standard OIs mapOI = (MapObjectInspector) getStandardObjectInspector(); - mapObj = getStandardObject(); - keyObj = ((HiveData) key).getStandardObject(); + mapObj = HiveWrapper.getStandardObject(this); + keyObj = HiveWrapper.getStandardObject(key); } - return HiveWrapper.createStdData( + return (V) HiveWrapper.createStdData( mapOI.getMapValueElement(mapObj, keyObj), mapOI.getMapValueObjectInspector(), _stdFactory); } @Override - public void put(StdData key, StdData value) { + public void put(K key, V value) { if (_mapObjectInspector instanceof SettableMapObjectInspector) { - Object keyObj = ((HiveData) key).getUnderlyingDataForObjectInspector(_keyObjectInspector); - Object valueObj = ((HiveData) value).getUnderlyingDataForObjectInspector(_valueObjectInspector); + Object keyObj = HiveWrapper.getPlatformDataForObjectInspector(key, _keyObjectInspector); + Object valueObj = HiveWrapper.getPlatformDataForObjectInspector(value, _valueObjectInspector); ((SettableMapObjectInspector) _mapObjectInspector).put( _object, @@ -79,11 +78,11 @@ 
public void put(StdData key, StdData value) { //TODO: Cache the result of .getMap(_object) below for subsequent calls. @Override - public Set keySet() { - return new AbstractSet() { + public Set keySet() { + return new AbstractSet() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Iterator mapKeyIterator = _mapObjectInspector.getMap(_object).keySet().iterator(); @Override @@ -92,26 +91,26 @@ public boolean hasNext() { } @Override - public StdData next() { - return HiveWrapper.createStdData(mapKeyIterator.next(), _keyObjectInspector, _stdFactory); + public K next() { + return (K) HiveWrapper.createStdData(mapKeyIterator.next(), _keyObjectInspector, _stdFactory); } }; } @Override public int size() { - return HiveMap.this.size(); + return HiveMapData.this.size(); } }; } //TODO: Cache the result of .getMap(_object) below for subsequent calls. @Override - public Collection values() { - return new AbstractCollection() { + public Collection values() { + return new AbstractCollection() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Iterator mapValueIterator = _mapObjectInspector.getMap(_object).values().iterator(); @Override @@ -120,30 +119,30 @@ public boolean hasNext() { } @Override - public StdData next() { - return HiveWrapper.createStdData(mapValueIterator.next(), _valueObjectInspector, _stdFactory); + public V next() { + return (V) HiveWrapper.createStdData(mapValueIterator.next(), _valueObjectInspector, _stdFactory); } }; } @Override public int size() { - return HiveMap.this.size(); + return HiveMapData.this.size(); } }; } @Override - public boolean containsKey(StdData key) { + public boolean containsKey(K key) { Object mapObj = _object; Object keyObj; try { - keyObj = ((HiveData) key).getUnderlyingDataForObjectInspector(_keyObjectInspector); + keyObj = HiveWrapper.getPlatformDataForObjectInspector(key, 
_keyObjectInspector); } catch (RuntimeException e) { // Cannot convert key argument to Map's KeyOI. So convertboth the map and the key arg to // objects having standard OIs - mapObj = getStandardObject(); - keyObj = ((HiveData) key).getStandardObject(); + mapObj = HiveWrapper.getStandardObject(this); + keyObj = HiveWrapper.getStandardObject(key); } return ((Map) mapObj).containsKey(keyObj); diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveStruct.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveRowData.java similarity index 79% rename from transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveStruct.java rename to transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveRowData.java index 80872eff..5704374e 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveStruct.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveRowData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.hive.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.hive.HiveWrapper; import java.util.List; import java.util.stream.Collectors; @@ -18,18 +17,18 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -public class HiveStruct extends HiveData implements StdStruct { +public class HiveRowData extends HiveData implements RowData { StructObjectInspector _structObjectInspector; - public HiveStruct(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { + public HiveRowData(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { super(stdFactory); _object = object; _structObjectInspector = (StructObjectInspector) objectInspector; } @Override - public StdData getField(int index) { + 
public Object getField(int index) { StructField structField = _structObjectInspector.getAllStructFieldRefs().get(index); return HiveWrapper.createStdData( _structObjectInspector.getStructFieldData(_object, structField), @@ -38,7 +37,7 @@ public StdData getField(int index) { } @Override - public StdData getField(String name) { + public Object getField(String name) { StructField structField = _structObjectInspector.getStructFieldRef(name); return HiveWrapper.createStdData( _structObjectInspector.getStructFieldData(_object, structField), @@ -47,11 +46,11 @@ public StdData getField(String name) { } @Override - public void setField(int index, StdData value) { + public void setField(int index, Object value) { if (_structObjectInspector instanceof SettableStructObjectInspector) { StructField field = _structObjectInspector.getAllStructFieldRefs().get(index); ((SettableStructObjectInspector) _structObjectInspector).setStructFieldData(_object, - field, ((HiveData) value).getUnderlyingDataForObjectInspector(field.getFieldObjectInspector()) + field, HiveWrapper.getPlatformDataForObjectInspector(value, field.getFieldObjectInspector()) ); _isObjectModified = true; } else { @@ -61,11 +60,11 @@ public void setField(int index, StdData value) { } @Override - public void setField(String name, StdData value) { + public void setField(String name, Object value) { if (_structObjectInspector instanceof SettableStructObjectInspector) { StructField field = _structObjectInspector.getStructFieldRef(name); ((SettableStructObjectInspector) _structObjectInspector).setStructFieldData(_object, - field, ((HiveData) value).getUnderlyingDataForObjectInspector(field.getFieldObjectInspector())); + field, HiveWrapper.getPlatformDataForObjectInspector(value, field.getFieldObjectInspector())); _isObjectModified = true; } else { throw new RuntimeException("Attempt to modify an immutable Hive object of type: " @@ -74,7 +73,7 @@ public void setField(String name, StdData value) { } @Override - public List 
fields() { + public List fields() { return IntStream.range(0, _structObjectInspector.getAllStructFieldRefs().size()).mapToObj(i -> getField(i)) .collect(Collectors.toList()); } diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveString.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveString.java deleted file mode 100644 index 83310309..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveString.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdString; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; - - -public class HiveString extends HiveData implements StdString { - - final StringObjectInspector _stringObjectInspector; - - public HiveString(Object object, StringObjectInspector stringObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _stringObjectInspector = stringObjectInspector; - } - - @Override - public String get() { - return _stringObjectInspector.getPrimitiveJavaObject(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _stringObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveStructType.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveRowType.java similarity index 83% rename from transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveStructType.java rename to transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveRowType.java index 
f4393776..c9ceef43 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveStructType.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveRowType.java @@ -5,7 +5,7 @@ */ package com.linkedin.transport.hive.types; -import com.linkedin.transport.api.types.StdStructType; +import com.linkedin.transport.api.types.RowType; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.hive.HiveWrapper; import java.util.List; @@ -13,11 +13,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -public class HiveStructType implements StdStructType { +public class HiveRowType implements RowType { final StructObjectInspector _structObjectInspector; - public HiveStructType(StructObjectInspector structObjectInspector) { + public HiveRowType(StructObjectInspector structObjectInspector) { _structObjectInspector = structObjectInspector; } diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala index 07b61ba2..87d5625d 100644 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala @@ -8,58 +8,36 @@ package com.linkedin.transport.spark import java.nio.ByteBuffer import java.util.{List => JavaList} -import com.google.common.base.Preconditions import com.linkedin.transport.api.StdFactory import com.linkedin.transport.api.data._ import com.linkedin.transport.api.types.StdType import com.linkedin.transport.spark.data._ import com.linkedin.transport.spark.typesystem.SparkTypeFactory import com.linkedin.transport.typesystem.{AbstractBoundVariables, TypeSignature} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types.{ArrayType, 
DataType, MapType, StructField, StructType} class SparkFactory(private val _boundVariables: AbstractBoundVariables[DataType]) extends StdFactory { private val _sparkTypeFactory: SparkTypeFactory = new SparkTypeFactory - override def createInteger(value: Int): StdInteger = SparkInteger(value) - - override def createLong(value: Long): StdLong = SparkLong(value) - - override def createBoolean(value: Boolean): StdBoolean = SparkBoolean(value) - - override def createString(value: String): StdString = { - Preconditions.checkNotNull(value, "Cannot create a null StdString".asInstanceOf[Any]) - SparkString(UTF8String.fromString(value)) - } - - override def createFloat(value: Float): StdFloat = SparkFloat(value) - - override def createDouble(value: Double): StdDouble = SparkDouble(value) - - override def createBinary(value: ByteBuffer): StdBinary = { - Preconditions.checkNotNull(value, "Cannot create a null StdBinary".asInstanceOf[Any]) - SparkBinary(value.array()) - } - - override def createArray(stdType: StdType): StdArray = createArray(stdType, 0) + override def createArray(stdType: StdType): ArrayData[_] = createArray(stdType, 0) // we do not pass size to `new Array()` as the size argument of createArray is supposed to be just a hint about - // the expected number of entries in the StdArray. `new Array(size)` will create an array with null entries - override def createArray(stdType: StdType, size: Int): StdArray = SparkArray( + // the expected number of entries in the ArrayData. 
`new Array(size)` will create an array with null entries + override def createArray(stdType: StdType, size: Int): ArrayData[_] = SparkArrayData( null, stdType.underlyingType().asInstanceOf[ArrayType] ) - override def createMap(stdType: StdType): StdMap = SparkMap( + override def createMap(stdType: StdType): MapData[_, _] = SparkMapData( //TODO: make these as separate mutable standard spark types null, stdType.underlyingType().asInstanceOf[MapType] ) - override def createStruct(fieldTypes: JavaList[StdType]): StdStruct = { + override def createStruct(fieldTypes: JavaList[StdType]): RowData = { createStruct(null, fieldTypes) } - override def createStruct(fieldNames: JavaList[String], fieldTypes: JavaList[StdType]): StdStruct = { + override def createStruct(fieldNames: JavaList[String], fieldTypes: JavaList[StdType]): RowData = { val structFields = new Array[StructField](fieldTypes.size()) (0 until fieldTypes.size()).foreach({ idx => { @@ -69,13 +47,13 @@ class SparkFactory(private val _boundVariables: AbstractBoundVariables[DataType] ) } }) - SparkStruct(null, StructType(structFields)) + SparkRowData(null, StructType(structFields)) } - override def createStruct(stdType: StdType): StdStruct = { + override def createStruct(stdType: StdType): RowData = { //TODO: make these as separate mutable standard spark types val structType: StructType = stdType.underlyingType().asInstanceOf[StructType] - SparkStruct(null, structType) + SparkRowData(null, structType) } override def createStdType(typeSignature: String): StdType = SparkWrapper.createStdType( diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala index 29e935db..b365f716 100644 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala @@ 
-7,38 +7,58 @@ package com.linkedin.transport.spark import java.nio.ByteBuffer -import com.linkedin.transport.api.data.StdData +import com.linkedin.transport.api.data.PlatformData import com.linkedin.transport.api.types.StdType import com.linkedin.transport.spark.data._ import com.linkedin.transport.spark.types._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String object SparkWrapper { - def createStdData(data: Any, dataType: DataType): StdData = { // scalastyle:ignore cyclomatic.complexity + def createStdData(data: Any, dataType: DataType): Object = { // scalastyle:ignore cyclomatic.complexity if (data == null) { null } else { dataType match { - case _: IntegerType => SparkInteger(data.asInstanceOf[Integer]) - case _: LongType => SparkLong(data.asInstanceOf[java.lang.Long]) - case _: BooleanType => SparkBoolean(data.asInstanceOf[java.lang.Boolean]) - case _: StringType => SparkString(data.asInstanceOf[UTF8String]) - case _: FloatType => SparkFloat(data.asInstanceOf[java.lang.Float]) - case _: DoubleType => SparkDouble(data.asInstanceOf[java.lang.Double]) - case _: BinaryType => SparkBinary(data.asInstanceOf[Array[Byte]]) - case _: ArrayType => SparkArray(data.asInstanceOf[ArrayData], dataType.asInstanceOf[ArrayType]) - case _: MapType => SparkMap(data.asInstanceOf[MapData], dataType.asInstanceOf[MapType]) - case _: StructType => SparkStruct(data.asInstanceOf[InternalRow], dataType.asInstanceOf[StructType]) + case _: IntegerType => data.asInstanceOf[Object] + case _: LongType => data.asInstanceOf[Object] + case _: BooleanType => data.asInstanceOf[Object] + case _: StringType => data.asInstanceOf[UTF8String].toString + case _: FloatType => data.asInstanceOf[Object] + case _: DoubleType => data.asInstanceOf[Object] + case _: BinaryType => ByteBuffer.wrap(data.asInstanceOf[Array[Byte]]) + case _: ArrayType => SparkArrayData( 
+ data.asInstanceOf[org.apache.spark.sql.catalyst.util.ArrayData], dataType.asInstanceOf[ArrayType] + ) + case _: MapType => SparkMapData( + data.asInstanceOf[org.apache.spark.sql.catalyst.util.MapData], dataType.asInstanceOf[MapType] + ) + case _: StructType => SparkRowData(data.asInstanceOf[InternalRow], dataType.asInstanceOf[StructType]) case _: NullType => null case _ => throw new UnsupportedOperationException("Unrecognized Spark Type: " + dataType.getClass) } } } + def getPlatformData(transportData: Object): Object = { + if (transportData == null) { + null + } else { + transportData match { + case _: java.lang.Integer => transportData + case _: java.lang.Long => transportData + case _: java.lang.Float => transportData + case _: java.lang.Double => transportData + case _: java.lang.Boolean => transportData + case _: java.lang.String => UTF8String.fromString(transportData.asInstanceOf[String]) + case _: ByteBuffer => transportData.asInstanceOf[ByteBuffer].array() + case _ => transportData.asInstanceOf[PlatformData].getUnderlyingData + } + } + } + def createStdType(dataType: DataType): StdType = dataType match { case _: IntegerType => SparkIntegerType(dataType.asInstanceOf[IntegerType]) case _: LongType => SparkLongType(dataType.asInstanceOf[LongType]) @@ -49,7 +69,7 @@ object SparkWrapper { case _: BinaryType => SparkBinaryType(dataType.asInstanceOf[BinaryType]) case _: ArrayType => SparkArrayType(dataType.asInstanceOf[ArrayType]) case _: MapType => SparkMapType(dataType.asInstanceOf[MapType]) - case _: StructType => SparkStructType(dataType.asInstanceOf[StructType]) + case _: StructType => SparkRowType(dataType.asInstanceOf[StructType]) case _: NullType => SparkUnknownType(dataType.asInstanceOf[NullType]) case _ => throw new UnsupportedOperationException("Unrecognized Spark Type: " + dataType.getClass) } diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala 
b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala index f1cca7d4..5eca65a1 100644 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala @@ -10,7 +10,6 @@ import java.nio.file.Paths import java.util.List import com.linkedin.transport.api.StdFactory -import com.linkedin.transport.api.data.{PlatformData, StdData} import com.linkedin.transport.api.udf._ import com.linkedin.transport.spark.typesystem.SparkTypeInference import com.linkedin.transport.utils.FileSystemUtils @@ -64,29 +63,29 @@ abstract class StdUdfWrapper(_expressions: Seq[Expression]) extends Expression if (wrappedConstants != null) { val requiredFiles = wrappedConstants.length match { case 0 => - _stdUdf.asInstanceOf[StdUDF0[StdData]].getRequiredFiles() + _stdUdf.asInstanceOf[StdUDF0[Object]].getRequiredFiles() case 1 => - _stdUdf.asInstanceOf[StdUDF1[StdData, StdData]].getRequiredFiles(wrappedConstants(0)) + _stdUdf.asInstanceOf[StdUDF1[Object, Object]].getRequiredFiles(wrappedConstants(0)) case 2 => - _stdUdf.asInstanceOf[StdUDF2[StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF2[Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1)) case 3 => - _stdUdf.asInstanceOf[StdUDF3[StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF3[Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2)) case 4 => - _stdUdf.asInstanceOf[StdUDF4[StdData, StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF4[Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3)) case 5 => - _stdUdf.asInstanceOf[StdUDF5[StdData, StdData, 
StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF5[Object, Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3), wrappedConstants(4)) case 6 => - _stdUdf.asInstanceOf[StdUDF6[StdData, StdData, StdData, StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF6[Object, Object, Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3), wrappedConstants(4), wrappedConstants(5)) case 7 => - _stdUdf.asInstanceOf[StdUDF7[StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF7[Object, Object, Object, Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3), wrappedConstants(4), wrappedConstants(5), wrappedConstants(6)) case 8 => - _stdUdf.asInstanceOf[StdUDF8[StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF8[Object, Object, Object, Object, Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3), wrappedConstants(4), wrappedConstants(5), wrappedConstants(6), wrappedConstants(7)) case _ => throw new UnsupportedOperationException("getRequiredFiles not yet supported for StdUDF" + _expressions.length) @@ -108,8 +107,8 @@ abstract class StdUdfWrapper(_expressions: Seq[Expression]) extends Expression } } // scalastyle:on magic.number - private final def checkNullsAndWrapConstants(): Array[StdData] = { - val wrappedConstants = new Array[StdData](_expressions.length) + private final def checkNullsAndWrapConstants(): Array[Object] = { + val wrappedConstants = new 
Array[Object](_expressions.length) for (i <- _expressions.indices) { val constantValue = if (_expressions(i).foldable) _expressions(i).eval() else null if (!_nullableArguments(i) && _expressions(i).foldable && constantValue == null) { @@ -135,42 +134,41 @@ abstract class StdUdfWrapper(_expressions: Seq[Expression]) extends Expression } val stdResult = wrappedArguments.length match { case 0 => - _stdUdf.asInstanceOf[StdUDF0[StdData]].eval() + _stdUdf.asInstanceOf[StdUDF0[Object]].eval() case 1 => - _stdUdf.asInstanceOf[StdUDF1[StdData, StdData]].eval(wrappedArguments(0)) + _stdUdf.asInstanceOf[StdUDF1[Object, Object]].eval(wrappedArguments(0)) case 2 => - _stdUdf.asInstanceOf[StdUDF2[StdData, StdData, StdData]].eval(wrappedArguments(0), wrappedArguments(1)) + _stdUdf.asInstanceOf[StdUDF2[Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1)) case 3 => - _stdUdf.asInstanceOf[StdUDF3[StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), wrappedArguments(1), + _stdUdf.asInstanceOf[StdUDF3[Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2)) case 4 => - _stdUdf.asInstanceOf[StdUDF4[StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + _stdUdf.asInstanceOf[StdUDF4[Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), wrappedArguments(3)) case 5 => - _stdUdf.asInstanceOf[StdUDF5[StdData, StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + _stdUdf.asInstanceOf[StdUDF5[Object, Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), wrappedArguments(3), wrappedArguments(4)) case 6 => - _stdUdf.asInstanceOf[StdUDF6[StdData, StdData, StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + _stdUdf.asInstanceOf[StdUDF6[Object, Object, Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), 
wrappedArguments(3), wrappedArguments(4), wrappedArguments(5)) case 7 => - _stdUdf.asInstanceOf[StdUDF7[StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + _stdUdf.asInstanceOf[StdUDF7[Object, Object, Object, Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), wrappedArguments(3), wrappedArguments(4), wrappedArguments(5), wrappedArguments(6)) case 8 => - _stdUdf.asInstanceOf[StdUDF8[StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + _stdUdf.asInstanceOf[StdUDF8[Object, Object, Object, Object, Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), wrappedArguments(3), wrappedArguments(4), wrappedArguments(5), wrappedArguments(6), wrappedArguments(7)) case _ => throw new UnsupportedOperationException("eval not yet supported for StdUDF" + _expressions.length) } - if (stdResult == null) null else stdResult.asInstanceOf[PlatformData].getUnderlyingData + SparkWrapper.getPlatformData(stdResult) } } // scalastyle:on magic.number - - private final def checkNullsAndWrapArguments(input: InternalRow): Array[StdData] = { - val wrappedArguments = new Array[StdData](_expressions.length) + private final def checkNullsAndWrapArguments(input: InternalRow): Array[Object] = { + val wrappedArguments = new Array[Object](_expressions.length) for (i <- _expressions.indices) { val evaluatedExpression = _expressions(i).eval(input) if(!_nullableArguments(i) && evaluatedExpression == null) { diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkArray.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkArrayData.scala similarity index 75% rename from transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkArray.scala rename to 
transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkArrayData.scala index 9fe91cab..e98ef069 100644 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkArray.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkArrayData.scala @@ -7,20 +7,19 @@ package com.linkedin.transport.spark.data import java.util -import com.linkedin.transport.api.data.{PlatformData, StdArray, StdData} +import com.linkedin.transport.api.data.{ArrayData, PlatformData} import com.linkedin.transport.spark.SparkWrapper -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{ArrayType, DataType} import scala.collection.mutable.ArrayBuffer -case class SparkArray(private var _arrayData: ArrayData, - private val _arrayType: DataType) extends StdArray with PlatformData { +case class SparkArrayData[E](private var _arrayData: org.apache.spark.sql.catalyst.util.ArrayData, + private val _arrayType: DataType) extends ArrayData[E] with PlatformData { private val _elementType = _arrayType.asInstanceOf[ArrayType].elementType private var _mutableBuffer: ArrayBuffer[Any] = if (_arrayData == null) createMutableArray() else null - override def add(e: StdData): Unit = { + override def add(e: E): Unit = { // Once add is called, we cannot use Spark's readonly ArrayData API // we have to add elements to a mutable buffer and start using that // always instead of the readonly stdType @@ -29,7 +28,7 @@ case class SparkArray(private var _arrayData: ArrayData, _mutableBuffer = createMutableArray() } // TODO: Does not support inserting nulls. Should we? 
- _mutableBuffer.append(e.asInstanceOf[PlatformData].getUnderlyingData) + _mutableBuffer.append(SparkWrapper.getPlatformData(e.asInstanceOf[Object])) } private def createMutableArray(): ArrayBuffer[Any] = { @@ -47,20 +46,20 @@ case class SparkArray(private var _arrayData: ArrayData, if (_mutableBuffer == null) { _arrayData } else { - ArrayData.toArrayData(_mutableBuffer) + org.apache.spark.sql.catalyst.util.ArrayData.toArrayData(_mutableBuffer) } } override def setUnderlyingData(value: scala.Any): Unit = { - _arrayData = value.asInstanceOf[ArrayData] + _arrayData = value.asInstanceOf[org.apache.spark.sql.catalyst.util.ArrayData] _mutableBuffer = null } - override def iterator(): util.Iterator[StdData] = { - new util.Iterator[StdData] { + override def iterator(): util.Iterator[E] = { + new util.Iterator[E] { private var idx = 0 - override def next(): StdData = { + override def next(): E = { val e = get(idx) idx += 1 e @@ -78,11 +77,11 @@ case class SparkArray(private var _arrayData: ArrayData, } } - override def get(idx: Int): StdData = { + override def get(idx: Int): E = { if (_mutableBuffer == null) { - SparkWrapper.createStdData(_arrayData.get(idx, _elementType), _elementType) + SparkWrapper.createStdData(_arrayData.get(idx, _elementType), _elementType).asInstanceOf[E] } else { - SparkWrapper.createStdData(_mutableBuffer(idx), _elementType) + SparkWrapper.createStdData(_mutableBuffer(idx), _elementType).asInstanceOf[E] } } } diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala deleted file mode 100644 index bd402530..00000000 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala +++ /dev/null @@ -1,19 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. 
- * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import java.nio.ByteBuffer - -import com.linkedin.transport.api.data.{PlatformData, StdBinary} - -case class SparkBinary(private var _bytes: Array[Byte]) extends StdBinary with PlatformData { - - override def get(): ByteBuffer = ByteBuffer.wrap(_bytes) - - override def getUnderlyingData: AnyRef = _bytes - - override def setUnderlyingData(value: scala.Any): Unit = _bytes = value.asInstanceOf[ByteBuffer].array() -} diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkBoolean.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkBoolean.scala deleted file mode 100644 index 2477eef2..00000000 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkBoolean.scala +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdBoolean} - -case class SparkBoolean(private var _bool: java.lang.Boolean) extends StdBoolean with PlatformData { - - override def get(): Boolean = _bool.booleanValue() - - override def getUnderlyingData: AnyRef = _bool - - override def setUnderlyingData(value: scala.Any): Unit = _bool = value.asInstanceOf[java.lang.Boolean] -} diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala deleted file mode 100644 index 6a4820e3..00000000 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. 
- * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdDouble} - -case class SparkDouble(private var _double: java.lang.Double) extends StdDouble with PlatformData { - - override def get(): Double = _double.doubleValue() - - override def getUnderlyingData: AnyRef = _double - - override def setUnderlyingData(value: scala.Any): Unit = _double = value.asInstanceOf[java.lang.Double] -} - diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala deleted file mode 100644 index d9842b51..00000000 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdFloat} - -case class SparkFloat(private var _float: java.lang.Float) extends StdFloat with PlatformData { - - override def get(): Float = _float.floatValue() - - override def getUnderlyingData: AnyRef = _float - - override def setUnderlyingData(value: scala.Any): Unit = _float = value.asInstanceOf[java.lang.Float] -} diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkInteger.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkInteger.scala deleted file mode 100644 index b7c0db9e..00000000 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkInteger.scala +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. 
- * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdInteger} - -case class SparkInteger(private var _int: Integer) extends StdInteger with PlatformData { - - override def get(): Int = _int.intValue() - - override def getUnderlyingData: AnyRef = _int - - override def setUnderlyingData(value: scala.Any): Unit = _int = value.asInstanceOf[Integer] -} diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkLong.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkLong.scala deleted file mode 100644 index 5a534290..00000000 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkLong.scala +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdLong} - -case class SparkLong(private var _long: java.lang.Long) extends StdLong with PlatformData { - - override def get(): Long = _long.longValue() - - override def getUnderlyingData: AnyRef = _long - - override def setUnderlyingData(value: scala.Any): Unit = _long = value.asInstanceOf[java.lang.Long] - -} diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala deleted file mode 100644 index d200be8c..00000000 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. 
- * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import java.util - -import com.linkedin.transport.api.data.{PlatformData, StdData, StdMap} -import com.linkedin.transport.spark.SparkWrapper -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData} -import org.apache.spark.sql.types.MapType - -import scala.collection.mutable.Map - - -case class SparkMap(private var _mapData: MapData, - private val _mapType: MapType) extends StdMap with PlatformData { - - private val _keyType = _mapType.keyType - private val _valueType = _mapType.valueType - private var _mutableMap: Map[Any, Any] = if (_mapData == null) createMutableMap() else null - - override def put(key: StdData, value: StdData): Unit = { - // TODO: Does not support inserting nulls. Should we? - if (_mutableMap == null) { - _mutableMap = createMutableMap() - } - _mutableMap.put(key.asInstanceOf[PlatformData].getUnderlyingData, value.asInstanceOf[PlatformData].getUnderlyingData) - } - - override def keySet(): util.Set[StdData] = { - val keysIterator: Iterator[Any] = if (_mutableMap == null) { - new Iterator[Any] { - var offset : Int = 0 - - override def next(): Any = { - offset += 1 - _mapData.keyArray().get(offset - 1, _keyType) - } - - override def hasNext: Boolean = { - offset < SparkMap.this.size() - } - } - } else { - _mutableMap.keysIterator - } - - new util.AbstractSet[StdData] { - - override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { - - override def next(): StdData = SparkWrapper.createStdData(keysIterator.next(), _keyType) - - override def hasNext: Boolean = keysIterator.hasNext - } - - override def size(): Int = SparkMap.this.size() - } - } - - override def size(): Int = { - if (_mutableMap == null) { - _mapData.numElements() - } else { - _mutableMap.size - } - } - - override def values(): util.Collection[StdData] = { - val valueIterator: Iterator[Any] = if (_mutableMap == null) { - new Iterator[Any] { 
- var offset : Int = 0 - - override def next(): Any = { - offset += 1 - _mapData.valueArray().get(offset - 1, _valueType) - } - - override def hasNext: Boolean = { - offset < SparkMap.this.size() - } - } - } else { - _mutableMap.valuesIterator - } - - new util.AbstractCollection[StdData] { - - override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { - - override def next(): StdData = SparkWrapper.createStdData(valueIterator.next(), _valueType) - - override def hasNext: Boolean = valueIterator.hasNext - } - - override def size(): Int = SparkMap.this.size() - } - } - - override def containsKey(key: StdData): Boolean = get(key) != null - - override def get(key: StdData): StdData = { - // Spark's complex data types (MapData, ArrayData, InternalRow) do not implement equals/hashcode - // If the key is of the above complex data types, get() will return null - if (_mutableMap == null) { - _mutableMap = createMutableMap() - } - SparkWrapper.createStdData(_mutableMap.get(key.asInstanceOf[PlatformData].getUnderlyingData).orNull, _valueType) - } - - private def createMutableMap(): Map[Any, Any] = { - val mutableMap = Map.empty[Any, Any] - if (_mapData != null) { - _mapData.foreach(_keyType, _valueType, (k, v) => mutableMap.put(k, v)) - } - mutableMap - } - - override def getUnderlyingData: AnyRef = { - if (_mutableMap == null) { - _mapData - } else { - ArrayBasedMapData(_mutableMap) - } - } - - override def setUnderlyingData(value: scala.Any): Unit = { - _mapData = value.asInstanceOf[MapData] - _mutableMap = null - } -} diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkMapData.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkMapData.scala new file mode 100644 index 00000000..cd9679c8 --- /dev/null +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkMapData.scala @@ -0,0 +1,106 @@ +/** + * Copyright 2018 LinkedIn Corporation. 
All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.spark.data + +import java.util + +import com.linkedin.transport.api.data.{MapData, PlatformData} +import com.linkedin.transport.spark.SparkWrapper +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.types.MapType + +import scala.collection.mutable.Map + + +case class SparkMapData[K, V](private var _mapData: org.apache.spark.sql.catalyst.util.MapData, + private val _mapType: MapType) extends MapData[K, V] with PlatformData { + + private val _keyType = _mapType.keyType + private val _valueType = _mapType.valueType + private var _mutableMap: Map[Any, Any] = if (_mapData == null) createMutableMap() else null + + override def put(key: K, value: V): Unit = { + // TODO: Does not support inserting nulls. Should we? + if (_mutableMap == null) { + _mutableMap = createMutableMap() + } + _mutableMap.put( + SparkWrapper.getPlatformData(key.asInstanceOf[Object]), + SparkWrapper.getPlatformData(value.asInstanceOf[Object]) + ) + } + + override def keySet(): util.Set[K] = { + new util.AbstractSet[K] { + + override def iterator(): util.Iterator[K] = new util.Iterator[K] { + private val keysIterator = if (_mutableMap == null) _mapData.keyArray().array.iterator else _mutableMap.keysIterator + + override def next(): K = SparkWrapper.createStdData(keysIterator.next(), _keyType).asInstanceOf[K] + + override def hasNext: Boolean = keysIterator.hasNext + } + + override def size(): Int = SparkMapData.this.size() + } + } + + override def size(): Int = { + if (_mutableMap == null) { + _mapData.numElements() + } else { + _mutableMap.size + } + } + + override def values(): util.Collection[V] = { + new util.AbstractCollection[V] { + + override def iterator(): util.Iterator[V] = new util.Iterator[V] { + private val valueIterator = if (_mutableMap == null) _mapData.valueArray().array.iterator else 
_mutableMap.valuesIterator + + override def next(): V = SparkWrapper.createStdData(valueIterator.next(), _valueType).asInstanceOf[V] + + override def hasNext: Boolean = valueIterator.hasNext + } + + override def size(): Int = SparkMapData.this.size() + } + } + + override def containsKey(key: K): Boolean = get(key) != null + + override def get(key: K): V = { + // Spark's complex data types (MapData, ArrayData, InternalRow) do not implement equals/hashcode + // If the key is of the above complex data types, get() will return null + if (_mutableMap == null) { + _mutableMap = createMutableMap() + } + SparkWrapper.createStdData(_mutableMap.get(SparkWrapper.getPlatformData(key.asInstanceOf[Object])).orNull, _valueType) + .asInstanceOf[V] + } + + private def createMutableMap(): Map[Any, Any] = { + val mutableMap = Map.empty[Any, Any] + if (_mapData != null) { + _mapData.foreach(_keyType, _valueType, (k, v) => mutableMap.put(k, v)) + } + mutableMap + } + + override def getUnderlyingData: AnyRef = { + if (_mutableMap == null) { + _mapData + } else { + ArrayBasedMapData(_mutableMap) + } + } + + override def setUnderlyingData(value: scala.Any): Unit = { + _mapData = value.asInstanceOf[org.apache.spark.sql.catalyst.util.MapData] + _mutableMap = null + } +} diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkStruct.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkRowData.scala similarity index 74% rename from transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkStruct.scala rename to transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkRowData.scala index ba432905..9cbc883e 100644 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkStruct.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkRowData.scala @@ -7,7 +7,7 @@ package 
com.linkedin.transport.spark.data import java.util.{List => JavaList} -import com.linkedin.transport.api.data.{PlatformData, StdData, StdStruct} +import com.linkedin.transport.api.data.{PlatformData, RowData} import com.linkedin.transport.spark.SparkWrapper import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType @@ -16,14 +16,14 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -case class SparkStruct(private var _row: InternalRow, - private val _structType: StructType) extends StdStruct with PlatformData { +case class SparkRowData(private var _row: InternalRow, + private val _structType: StructType) extends RowData with PlatformData { private var _mutableBuffer: ArrayBuffer[Any] = if (_row == null) createMutableStruct() else null - override def getField(name: String): StdData = getField(_structType.fieldIndex(name)) + override def getField(name: String): Object = getField(_structType.fieldIndex(name)) - override def getField(index: Int): StdData = { + override def getField(index: Int): Object = { val fieldDataType = _structType(index).dataType if (_mutableBuffer == null) { SparkWrapper.createStdData(_row.get(index, fieldDataType), fieldDataType) @@ -32,15 +32,15 @@ case class SparkStruct(private var _row: InternalRow, } } - override def setField(name: String, value: StdData): Unit = { + override def setField(name: String, value: Object): Unit = { setField(_structType.fieldIndex(name), value) } - override def setField(index: Int, value: StdData): Unit = { + override def setField(index: Int, value: Object): Unit = { if (_mutableBuffer == null) { _mutableBuffer = createMutableStruct() } - _mutableBuffer(index) = value.asInstanceOf[PlatformData].getUnderlyingData + _mutableBuffer(index) = SparkWrapper.getPlatformData(value) } private def createMutableStruct() = { @@ -51,7 +51,7 @@ case class SparkStruct(private var _row: InternalRow, } } - override def fields(): JavaList[StdData] = { + 
override def fields(): JavaList[Object] = { _structType.indices.map(getField).asJava } diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkString.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkString.scala deleted file mode 100644 index bd089dd5..00000000 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkString.scala +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdString} -import org.apache.spark.unsafe.types.UTF8String - -case class SparkString(private var _str: UTF8String) extends StdString with PlatformData { - - override def get(): String = _str.toString - - override def getUnderlyingData: AnyRef = _str - - override def setUnderlyingData(value: scala.Any): Unit = _str = value.asInstanceOf[UTF8String] -} diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala index 45fdc5c5..554a282d 100644 --- a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala @@ -70,7 +70,7 @@ case class SparkMapType(mapType: MapType) extends StdMapType { override def valueType(): StdType = SparkWrapper.createStdType(mapType.valueType) } -case class SparkStructType(structType: StructType) extends StdStructType { +case class SparkRowType(structType: StructType) extends RowType { override def underlyingType(): DataType = structType diff --git 
a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala index e9c6304a..20221e47 100644 --- a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala +++ b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala @@ -23,18 +23,6 @@ class TestSparkFactory { val typeFactory: SparkTypeFactory = new SparkTypeFactory val stdFactory = new SparkFactory(new SparkBoundVariables) - @Test - def testCreatePrimitives(): Unit = { - assertEquals(stdFactory.createInteger(1).get(), 1) - assertEquals(stdFactory.createLong(1L).get(), 1L) - assertEquals(stdFactory.createBoolean(true).get(), true) - assertEquals(stdFactory.createString("").get(), "") - assertEquals(stdFactory.createFloat(2.0f).get(), 2.0f) - assertEquals(stdFactory.createDouble(3.0).get(), 3.0) - val byteArray = "foo".getBytes(Charset.forName("UTF-8")) - assertEquals(stdFactory.createBinary(ByteBuffer.wrap(byteArray)).get().array(), byteArray) - } - @Test def testCreateArray(): Unit = { var stdArray = stdFactory.createArray(stdFactory.createStdType("array(integer)")) diff --git a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala index 00d70d88..dfc024ac 100644 --- a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala +++ b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala @@ -5,7 +5,8 @@ */ package com.linkedin.transport.spark.data -import com.linkedin.transport.api.data.{PlatformData, StdArray} +import com.linkedin.transport.api.data +import com.linkedin.transport.api.data.{PlatformData} import com.linkedin.transport.spark.{SparkFactory, SparkWrapper} 
import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{ArrayType, DataTypes} @@ -20,35 +21,33 @@ class TestSparkArray { @Test def testCreateSparkArray(): Unit = { - val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[StdArray] + val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[data.ArrayData[Integer]] assertEquals(stdArray.size(), arrayData.numElements()) assertSame(stdArray.asInstanceOf[PlatformData].getUnderlyingData, arrayData) } @Test def testSparkArrayGet(): Unit = { - val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[StdArray] + val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[data.ArrayData[Integer]] (0 until stdArray.size).foreach(idx => { - assertEquals(stdArray.get(idx).asInstanceOf[SparkInteger].get(), idx) + assertEquals(stdArray.get(idx), idx) }) } @Test def testSparkArrayAdd(): Unit = { - val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[StdArray] - val insert = stdFactory.createInteger(5) // scalastyle:ignore magic.number - stdArray.add(insert) + val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[data.ArrayData[Integer]] + stdArray.add(5) // Since original ArrayData is immutable, a mutable ArrayBuffer should be created and set as the underlying object assertNotSame(stdArray.asInstanceOf[PlatformData].getUnderlyingData, arrayData) assertEquals(stdArray.size(), arrayData.numElements() + 1) - assertEquals(stdArray.get(stdArray.size() - 1), insert) + assertEquals(stdArray.get(stdArray.size() - 1), 5) } @Test def testSparkArrayMutabilityReset(): Unit = { - val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[StdArray] - val insert = stdFactory.createInteger(5) // scalastyle:ignore magic.number - stdArray.add(insert) + val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[data.ArrayData[Integer]] + stdArray.add(5) 
stdArray.asInstanceOf[PlatformData].setUnderlyingData(arrayData) // After underlying data is explicitly set, mutuable buffer should be removed assertSame(stdArray.asInstanceOf[PlatformData].getUnderlyingData, arrayData) diff --git a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala index 608c027f..12675eb1 100644 --- a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala +++ b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala @@ -5,7 +5,7 @@ */ package com.linkedin.transport.spark.data -import com.linkedin.transport.api.data.{PlatformData, StdMap, StdString} +import com.linkedin.transport.api.data.{MapData, PlatformData} import com.linkedin.transport.spark.{SparkFactory, SparkWrapper} import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types.{DataTypes, MapType} @@ -23,58 +23,54 @@ class TestSparkMap { @Test def testCreateSparkMap(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] assertEquals(stdMap.size(), mapData.numElements()) assertSame(stdMap.asInstanceOf[PlatformData].getUnderlyingData, mapData) } @Test def testSparkMapKeySet(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - assertEqualsNoOrder(stdMap.keySet().toArray, mapData.keyArray.array.map(s => stdFactory.createString(s.toString))) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + assertEqualsNoOrder(stdMap.keySet().toArray, mapData.keyArray.array.map(s => s.toString)) } @Test def testSparkMapValues(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - 
assertEqualsNoOrder(stdMap.values().toArray, mapData.valueArray.array.map(s => stdFactory.createString(s.toString))) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + assertEqualsNoOrder(stdMap.values().toArray, mapData.valueArray.array.map(s => s.toString)) } @Test def testSparkMapGet(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] mapData.keyArray.foreach(mapType.keyType, (idx, key) => { - assertEquals(stdMap.get(stdFactory.createString(key.toString)).asInstanceOf[StdString].get, + assertEquals(stdMap.get(key.toString), mapData.valueArray.array(idx).toString) }) - assertEquals(stdMap.containsKey(stdFactory.createString("nonExistentKey")), false) - // Even for a get in SparkMap we create mutable Map since Spark's Impl is based of arrays. So underlying object should change + assertEquals(stdMap.containsKey("nonExistentKey"), false) + // Even for a get in SparkMapData we create mutable Map since Spark's Impl is based of arrays. 
So underlying object should change assertNotSame(stdMap.asInstanceOf[PlatformData].getUnderlyingData, mapData) } @Test def testSparkMapContainsKey(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - assertEquals(stdMap.containsKey(stdFactory.createString("k3")), true) - assertEquals(stdMap.containsKey(stdFactory.createString("k4")), false) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + assertEquals(stdMap.containsKey("k3"), true) + assertEquals(stdMap.containsKey("k4"), false) } @Test def testSparkMapPut(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - val insertKey = stdFactory.createString("k4") - val insertVal = stdFactory.createString("v4") - stdMap.put(insertKey, insertVal) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + stdMap.put("k4", "v4") assertEquals(stdMap.size(), mapData.numElements() + 1) - assertEquals(stdMap.get(stdFactory.createString("k4")), insertVal) + assertEquals(stdMap.get("k4"), "v4") } @Test def testSparkMapMutabilityReset(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - val insertKey = stdFactory.createString("k4") - val insertVal = stdFactory.createString("v4") - stdMap.put(insertKey, insertVal) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + stdMap.put("k4", "v4") stdMap.asInstanceOf[PlatformData].setUnderlyingData(mapData) // After underlying data is explicitly set, mutuable map should be removed assertSame(stdMap.asInstanceOf[PlatformData].getUnderlyingData, mapData) diff --git a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala deleted file mode 100644 index 21b88c8e..00000000 --- 
a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import java.lang -import java.nio.ByteBuffer -import java.nio.charset.Charset - -import com.linkedin.transport.api.data._ -import com.linkedin.transport.spark.{SparkFactory, SparkWrapper} -import org.apache.spark.sql.types.DataTypes -import org.apache.spark.unsafe.types.UTF8String -import org.testng.Assert.{assertEquals, assertSame} -import org.testng.annotations.Test - - -class TestSparkPrimitives { - - val stdFactory = new SparkFactory(null) - - @Test - def testCreateSparkInteger(): Unit = { - val intData = 123 - val stdInteger = SparkWrapper.createStdData(intData, DataTypes.IntegerType).asInstanceOf[StdInteger] - assertEquals(stdInteger.get(), intData) - assertSame(stdInteger.asInstanceOf[PlatformData].getUnderlyingData, intData) - } - - @Test - def testCreateSparkLong(): Unit = { - val longData = new lang.Long(1234L) // scalastyle:ignore magic.number - val stdLong = SparkWrapper.createStdData(longData, DataTypes.LongType).asInstanceOf[StdLong] - assertEquals(stdLong.get(), longData) - assertSame(stdLong.asInstanceOf[PlatformData].getUnderlyingData, longData) - } - - @Test - def testCreateSparkBoolean(): Unit = { - val booleanData = new lang.Boolean(true) - val stdBoolean = SparkWrapper.createStdData(booleanData, DataTypes.BooleanType).asInstanceOf[StdBoolean] - assertEquals(stdBoolean.get(), true) - assertSame(stdBoolean.asInstanceOf[PlatformData].getUnderlyingData, booleanData) - } - - @Test - def testCreateSparkString(): Unit = { - val stringData = UTF8String.fromString("test") - val stdString = SparkWrapper.createStdData(stringData, DataTypes.StringType).asInstanceOf[StdString] - 
assertEquals(stdString.get(), "test") - assertSame(stdString.asInstanceOf[PlatformData].getUnderlyingData, stringData) - } - - @Test - def testCreateSparkFloat(): Unit = { - val floatData = new lang.Float(1.0f) - val stdFloat = SparkWrapper.createStdData(floatData, DataTypes.FloatType).asInstanceOf[StdFloat] - assertEquals(stdFloat.get(), 1.0f) - assertSame(stdFloat.asInstanceOf[PlatformData].getUnderlyingData, floatData) - } - - @Test - def testCreateSparkDouble(): Unit = { - val doubleData = new lang.Double(2.0) - val stdDouble = SparkWrapper.createStdData(doubleData, DataTypes.DoubleType).asInstanceOf[StdDouble] - assertEquals(stdDouble.get(), 2.0) - assertSame(stdDouble.asInstanceOf[PlatformData].getUnderlyingData, doubleData) - } - - @Test - def testCreateSparkBinary(): Unit = { - val bytesData = ByteBuffer.wrap("foo".getBytes(Charset.forName("UTF-8"))) - val stdByte = SparkWrapper.createStdData(bytesData.array(), DataTypes.BinaryType).asInstanceOf[StdBinary] - assertEquals(stdByte.get(), bytesData) - assertSame(stdByte.asInstanceOf[PlatformData].getUnderlyingData, bytesData.array()) - } - -} diff --git a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkStruct.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkRowData.scala similarity index 70% rename from transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkStruct.scala rename to transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkRowData.scala index 9a911af2..df7def17 100644 --- a/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkStruct.scala +++ b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkRowData.scala @@ -5,7 +5,7 @@ */ package com.linkedin.transport.spark.data -import com.linkedin.transport.api.data.{PlatformData, StdStruct} +import 
com.linkedin.transport.api.data.{PlatformData, RowData} import com.linkedin.transport.spark.{SparkFactory, SparkWrapper} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData @@ -14,7 +14,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.testng.Assert.{assertEquals, assertNotSame, assertSame} import org.testng.annotations.Test -class TestSparkStruct { +class TestSparkRowData { val stdFactory = new SparkFactory(null) val dataArray = Array(UTF8String.fromString("str1"), 0, 2L, false, ArrayData.toArrayData(Array.range(0, 5))) // scalastyle:ignore magic.number val fieldNames = Array("strField", "intField", "longField", "boolField", "arrField") @@ -25,41 +25,41 @@ class TestSparkStruct { @Test def testCreateSparkStruct(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] assertSame(stdStruct.asInstanceOf[PlatformData].getUnderlyingData, structData) } @Test def testSparkStructGetField(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] dataArray.indices.foreach(idx => { - assertEquals(stdStruct.getField(idx).asInstanceOf[PlatformData].getUnderlyingData, dataArray(idx)) - assertEquals(stdStruct.getField(fieldNames(idx)).asInstanceOf[PlatformData].getUnderlyingData, dataArray(idx)) + assertEquals(SparkWrapper.getPlatformData(stdStruct.getField(idx)), dataArray(idx)) + assertEquals(SparkWrapper.getPlatformData(stdStruct.getField(fieldNames(idx))), dataArray(idx)) }) } @Test def testSparkStructFields(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] assertEquals(stdStruct.fields().size(), 
structData.numFields) - assertEquals(stdStruct.fields().toArray.map(f => f.asInstanceOf[PlatformData].getUnderlyingData), dataArray) + assertEquals(stdStruct.fields().toArray.map(f => SparkWrapper.getPlatformData(f)), dataArray) } @Test def testSparkStructSetField(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] - stdStruct.setField(1, stdFactory.createInteger(1)) - assertEquals(stdStruct.getField(1).asInstanceOf[PlatformData].getUnderlyingData, 1) - stdStruct.setField(fieldNames(2), stdFactory.createLong(5)) // scalastyle:ignore magic.number - assertEquals(stdStruct.getField(fieldNames(2)).asInstanceOf[PlatformData].getUnderlyingData, 5L) // scalastyle:ignore magic.number + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] + stdStruct.setField(1, 1) + assertEquals(stdStruct.getField(1), 1) + stdStruct.setField(fieldNames(2), 5L) // scalastyle:ignore magic.number + assertEquals(stdStruct.getField(fieldNames(2)), 5L) // scalastyle:ignore magic.number // Since original InternalRow is immutable, a mutable ArrayBuffer should be created and set as the underlying object assertNotSame(stdStruct.asInstanceOf[PlatformData].getUnderlyingData, structData) } @Test def testSparkStructMutabilityReset(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] - stdStruct.setField(1, stdFactory.createInteger(1)) + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] + stdStruct.setField(1, 1) stdStruct.asInstanceOf[PlatformData].setUnderlyingData(structData) // After underlying data is explicitly set, mutable buffer should be removed assertSame(stdStruct.asInstanceOf[PlatformData].getUnderlyingData, structData) diff --git a/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java 
b/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java index 7dbc3d34..4e85393e 100644 --- a/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java +++ b/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java @@ -6,7 +6,9 @@ package com.linkedin.transport.test; import com.google.common.base.Preconditions; -import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.TopLevelStdUDF; import com.linkedin.transport.test.spi.FunctionCall; @@ -26,15 +28,12 @@ * An abstract class to be extended by all test classes. This class contains helper methods to initialize the * {@link StdTester} and create input and output data for the test cases. * - * The mapping between a {@link StdData} to the corresponding Java type is given below: + * Primitive data is represented by primitive types when passed to the test cases. + * The mapping between container types to the corresponding Java type is given below: *
- *   <li>{@link com.linkedin.transport.api.data.StdInteger} = {@link Integer}</li>
- *   <li>{@link com.linkedin.transport.api.data.StdLong} = {@link Long}</li>
- *   <li>{@link com.linkedin.transport.api.data.StdBoolean} = {@link Boolean}</li>
- *   <li>{@link com.linkedin.transport.api.data.StdString} = {@link String}</li>
- *   <li>{@link com.linkedin.transport.api.data.StdArray} = Use {@link #array(Object...)} to create arrays</li>
- *   <li>{@link com.linkedin.transport.api.data.StdMap} = Use {@link #map(Object...)} to create maps</li>
- *   <li>{@link com.linkedin.transport.api.data.StdStruct} = Use {@link #row(Object...)} to create structs</li>
+ *   <li>{@link ArrayData} = Use {@link #array(Object...)} to create arrays</li>
+ *   <li>{@link MapData} = Use {@link #map(Object...)} to create maps</li>
+ *   <li>{@link RowData} = Use {@link #row(Object...)} to create structs</li>
* * diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java index 58d6a921..d3d26338 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java @@ -5,35 +5,19 @@ */ package com.linkedin.transport.test.generic; -import com.google.common.base.Preconditions; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdBoolean; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdDouble; -import com.linkedin.transport.api.data.StdFloat; -import com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdLong; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.test.generic.data.GenericArray; -import com.linkedin.transport.test.generic.data.GenericBoolean; -import com.linkedin.transport.test.generic.data.GenericBinary; -import com.linkedin.transport.test.generic.data.GenericDouble; -import com.linkedin.transport.test.generic.data.GenericFloat; -import com.linkedin.transport.test.generic.data.GenericInteger; -import com.linkedin.transport.test.generic.data.GenericLong; -import com.linkedin.transport.test.generic.data.GenericMap; -import com.linkedin.transport.test.generic.data.GenericString; 
+import com.linkedin.transport.test.generic.data.GenericArrayData; +import com.linkedin.transport.test.generic.data.GenericMapData; import com.linkedin.transport.test.generic.data.GenericStruct; import com.linkedin.transport.test.generic.typesystem.GenericTypeFactory; import com.linkedin.transport.test.spi.types.TestType; import com.linkedin.transport.test.spi.types.TestTypeFactory; import com.linkedin.transport.typesystem.AbstractBoundVariables; import com.linkedin.transport.typesystem.TypeSignature; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -50,69 +34,33 @@ public GenericFactory(AbstractBoundVariables boundVariables) { } @Override - public StdInteger createInteger(int value) { - return new GenericInteger(value); + public ArrayData createArray(StdType stdType, int expectedSize) { + return new GenericArrayData(new ArrayList<>(expectedSize), (TestType) stdType.underlyingType()); } @Override - public StdLong createLong(long value) { - return new GenericLong(value); - } - - @Override - public StdBoolean createBoolean(boolean value) { - return new GenericBoolean(value); - } - - @Override - public StdString createString(String value) { - Preconditions.checkNotNull(value, "Cannot create a null StdString"); - return new GenericString(value); - } - - @Override - public StdFloat createFloat(float value) { - return new GenericFloat(value); - } - - @Override - public StdDouble createDouble(double value) { - return new GenericDouble(value); - } - - @Override - public StdBinary createBinary(ByteBuffer value) { - return new GenericBinary(value); - } - - @Override - public StdArray createArray(StdType stdType, int expectedSize) { - return new GenericArray(new ArrayList<>(expectedSize), (TestType) stdType.underlyingType()); - } - - @Override - public StdArray createArray(StdType stdType) { + public ArrayData createArray(StdType stdType) { return createArray(stdType, 0); } @Override - public StdMap 
createMap(StdType stdType) { - return new GenericMap((TestType) stdType.underlyingType()); + public MapData createMap(StdType stdType) { + return new GenericMapData((TestType) stdType.underlyingType()); } @Override - public StdStruct createStruct(List fieldNames, List fieldTypes) { + public RowData createStruct(List fieldNames, List fieldTypes) { return new GenericStruct(TestTypeFactory.struct(fieldNames, fieldTypes.stream().map(x -> (TestType) x.underlyingType()).collect(Collectors.toList()))); } @Override - public StdStruct createStruct(List fieldTypes) { + public RowData createStruct(List fieldTypes) { return createStruct(null, fieldTypes); } @Override - public StdStruct createStruct(StdType stdType) { + public RowData createStruct(StdType stdType) { return new GenericStruct((TestType) stdType.underlyingType()); } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericStdUDFWrapper.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericStdUDFWrapper.java index 2a345e3c..081e92a4 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericStdUDFWrapper.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericStdUDFWrapper.java @@ -7,7 +7,6 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.StdUDF1; @@ -24,6 +23,7 @@ import com.linkedin.transport.utils.FileSystemUtils; import java.io.IOException; import java.lang.reflect.InvocationTargetException; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -42,7 +42,7 
@@ public class GenericStdUDFWrapper { protected boolean _requiredFilesProcessed; protected StdFactory _stdFactory; private boolean[] _nullableArguments; - private StdData[] _args; + private Object[] _args; private Class _topLevelUdfClass; private List> _stdUdfImplementations; private String[] _localFiles; @@ -83,12 +83,18 @@ protected boolean containsNullValuedNonNullableArgument(Object[] arguments) { return false; } - protected StdData wrap(Object argument, StdData stdData) { - if (argument != null) { - ((PlatformData) stdData).setUnderlyingData(argument); - return stdData; - } else { + protected Object wrap(Object argument, Object stdData) { + if (argument == null) { return null; + } else { + if (argument instanceof Integer || argument instanceof Long || argument instanceof Boolean + || argument instanceof String || argument instanceof Double || argument instanceof Float + || argument instanceof ByteBuffer) { + return argument; + } else { + ((PlatformData) stdData).setUnderlyingData(argument); + return stdData; + } } } @@ -107,26 +113,26 @@ protected Class getTopLevelUdfClass() { } protected void createStdData() { - _args = new StdData[_inputTypes.length]; + _args = new Object[_inputTypes.length]; for (int i = 0; i < _inputTypes.length; i++) { _args[i] = GenericWrapper.createStdData(null, _inputTypes[i]); } } - private StdData[] wrapArguments(Object[] arguments) { - return IntStream.range(0, _args.length).mapToObj(i -> wrap(arguments[i], _args[i])).toArray(StdData[]::new); + private Object[] wrapArguments(Object[] arguments) { + return IntStream.range(0, _args.length).mapToObj(i -> wrap(arguments[i], _args[i])).toArray(Object[]::new); } public Object evaluate(Object[] arguments) { if (containsNullValuedNonNullableArgument(arguments)) { return null; } - StdData[] args = wrapArguments(arguments); + Object[] args = wrapArguments(arguments); if (!_requiredFilesProcessed) { String[] requiredFiles = getRequiredFiles(args); processRequiredFiles(requiredFiles); } - 
StdData result; + Object result; switch (args.length) { case 0: result = ((StdUDF0) _stdUdf).eval(); @@ -158,10 +164,10 @@ public Object evaluate(Object[] arguments) { default: throw new UnsupportedOperationException("eval not yet supported for StdUDF" + args.length); } - return result == null ? null : ((PlatformData) result).getUnderlyingData(); + return GenericWrapper.getPlatformData(result); } - public String[] getRequiredFiles(StdData[] args) { + public String[] getRequiredFiles(Object[] args) { String[] requiredFiles; switch (args.length) { case 0: diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java index 8754f0a8..e8707568 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java @@ -5,17 +5,10 @@ */ package com.linkedin.transport.test.generic; -import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.PlatformData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.test.generic.data.GenericArray; -import com.linkedin.transport.test.generic.data.GenericBoolean; -import com.linkedin.transport.test.generic.data.GenericBinary; -import com.linkedin.transport.test.generic.data.GenericDouble; -import com.linkedin.transport.test.generic.data.GenericFloat; -import com.linkedin.transport.test.generic.data.GenericInteger; -import com.linkedin.transport.test.generic.data.GenericLong; -import com.linkedin.transport.test.generic.data.GenericMap; -import com.linkedin.transport.test.generic.data.GenericString; +import com.linkedin.transport.test.generic.data.GenericArrayData; +import 
com.linkedin.transport.test.generic.data.GenericMapData; import com.linkedin.transport.test.generic.data.GenericStruct; import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.types.ArrayTestType; @@ -40,27 +33,17 @@ public class GenericWrapper { private GenericWrapper() { } - public static StdData createStdData(Object data, TestType dataType) { + public static Object createStdData(Object data, TestType dataType) { if (dataType instanceof UnknownTestType) { return null; - } else if (dataType instanceof IntegerTestType) { - return new GenericInteger((Integer) data); - } else if (dataType instanceof LongTestType) { - return new GenericLong((Long) data); - } else if (dataType instanceof BooleanTestType) { - return new GenericBoolean((Boolean) data); - } else if (dataType instanceof StringTestType) { - return new GenericString((String) data); - } else if (dataType instanceof FloatTestType) { - return new GenericFloat((Float) data); - } else if (dataType instanceof DoubleTestType) { - return new GenericDouble((Double) data); - } else if (dataType instanceof BinaryTestType) { - return new GenericBinary((ByteBuffer) data); + } else if (dataType instanceof IntegerTestType || dataType instanceof LongTestType + || dataType instanceof FloatTestType || dataType instanceof DoubleTestType + || dataType instanceof BooleanTestType || dataType instanceof StringTestType || dataType instanceof BinaryTestType) { + return data; } else if (dataType instanceof ArrayTestType) { - return new GenericArray((List) data, dataType); + return new GenericArrayData((List) data, dataType); } else if (dataType instanceof MapTestType) { - return new GenericMap((Map) data, dataType); + return new GenericMapData((Map) data, dataType); } else if (dataType instanceof StructTestType) { return new GenericStruct((Row) data, dataType); } else { @@ -68,6 +51,20 @@ public static StdData createStdData(Object data, TestType dataType) { } } + public static Object getPlatformData(Object 
transportData) { + if (transportData == null) { + return null; + } else { + if (transportData instanceof Integer || transportData instanceof Long || transportData instanceof Float + || transportData instanceof Double || transportData instanceof Boolean || transportData instanceof ByteBuffer + || transportData instanceof String) { + return transportData; + } else { + return ((PlatformData) transportData).getUnderlyingData(); + } + } + } + public static StdType createStdType(TestType dataType) { return () -> dataType; } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArray.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArrayData.java similarity index 65% rename from transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArray.java rename to transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArrayData.java index b2152c93..2aa85cb9 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArray.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArrayData.java @@ -5,9 +5,8 @@ */ package com.linkedin.transport.test.generic.data; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.test.generic.GenericWrapper; import com.linkedin.transport.test.spi.types.ArrayTestType; import com.linkedin.transport.test.spi.types.TestType; @@ -15,12 +14,12 @@ import java.util.List; -public class GenericArray implements StdArray, PlatformData { +public class GenericArrayData 
implements ArrayData, PlatformData { private List _array; private TestType _elementType; - public GenericArray(List data, TestType type) { + public GenericArrayData(List data, TestType type) { _array = data; _elementType = ((ArrayTestType) type).getElementType(); } @@ -31,18 +30,18 @@ public int size() { } @Override - public StdData get(int idx) { - return GenericWrapper.createStdData(_array.get(idx), _elementType); + public E get(int idx) { + return (E) GenericWrapper.createStdData(_array.get(idx), _elementType); } @Override - public void add(StdData e) { - _array.add(((PlatformData) e).getUnderlyingData()); + public void add(E e) { + _array.add(GenericWrapper.getPlatformData(e)); } @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { private final Iterator _iterator = _array.iterator(); @Override @@ -51,8 +50,8 @@ public boolean hasNext() { } @Override - public StdData next() { - return GenericWrapper.createStdData(_iterator.next(), _elementType); + public E next() { + return (E) GenericWrapper.createStdData(_iterator.next(), _elementType); } }; } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java deleted file mode 100644 index 391a6752..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdBinary; -import java.nio.ByteBuffer; - - -public class GenericBinary implements StdBinary, PlatformData { - - private ByteBuffer _byteBuffer; - - public GenericBinary(ByteBuffer aByteBuffer) { - _byteBuffer = aByteBuffer; - } - - @Override - public ByteBuffer get() { - return _byteBuffer; - } - - @Override - public Object getUnderlyingData() { - return _byteBuffer; - } - - @Override - public void setUnderlyingData(Object value) { - _byteBuffer = (ByteBuffer) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBoolean.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBoolean.java deleted file mode 100644 index e731a1e3..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBoolean.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdBoolean; - - -public class GenericBoolean implements StdBoolean, PlatformData { - private Boolean _boolean; - - public GenericBoolean(Boolean aBoolean) { - _boolean = aBoolean; - } - - @Override - public boolean get() { - return _boolean; - } - - @Override - public Object getUnderlyingData() { - return _boolean; - } - - @Override - public void setUnderlyingData(Object value) { - _boolean = (Boolean) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java deleted file mode 100644 index 05ac39bf..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdDouble; - - -public class GenericDouble implements StdDouble, PlatformData { - - private Double _double; - - public GenericDouble(Double aDouble) { - _double = aDouble; - } - - @Override - public double get() { - return _double; - } - - @Override - public Object getUnderlyingData() { - return _double; - } - - @Override - public void setUnderlyingData(Object value) { - _double = (Double) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java deleted file mode 100644 index 806787de..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdFloat; - - -public class GenericFloat implements StdFloat, PlatformData { - - private Float _float; - - public GenericFloat(Float aFloat) { - _float = aFloat; - } - - @Override - public float get() { - return _float; - } - - @Override - public Object getUnderlyingData() { - return _float; - } - - @Override - public void setUnderlyingData(Object value) { - _float = (Float) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericInteger.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericInteger.java deleted file mode 100644 index bcb1905c..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericInteger.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdInteger; - - -public class GenericInteger implements StdInteger, PlatformData { - private Integer _integer; - - public GenericInteger(Integer integer) { - _integer = integer; - } - - @Override - public int get() { - return _integer; - } - - @Override - public Object getUnderlyingData() { - return _integer; - } - - @Override - public void setUnderlyingData(Object value) { - _integer = (Integer) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericLong.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericLong.java deleted file mode 100644 index 85e9dac6..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericLong.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdLong; - - -public class GenericLong implements StdLong, PlatformData { - private Long _long; - - public GenericLong(Long aLong) { - _long = aLong; - } - - @Override - public long get() { - return _long; - } - - @Override - public Object getUnderlyingData() { - return _long; - } - - @Override - public void setUnderlyingData(Object value) { - _long = (Long) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMapData.java similarity index 59% rename from transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java rename to transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMapData.java index beeeb684..343fbac1 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMapData.java @@ -5,9 +5,8 @@ */ package com.linkedin.transport.test.generic.data; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; import com.linkedin.transport.test.generic.GenericWrapper; import com.linkedin.transport.test.spi.types.MapTestType; import com.linkedin.transport.test.spi.types.TestType; @@ -20,19 +19,19 @@ import java.util.stream.Collectors; -public class GenericMap implements StdMap, PlatformData { +public class GenericMapData implements 
MapData, PlatformData { private Map _map; private final TestType _keyType; private final TestType _valueType; - public GenericMap(Map map, TestType type) { + public GenericMapData(Map map, TestType type) { _map = map; _keyType = ((MapTestType) type).getKeyType(); _valueType = ((MapTestType) type).getValueType(); } - public GenericMap(TestType type) { + public GenericMapData(TestType type) { this(new LinkedHashMap<>(), type); } @@ -52,21 +51,21 @@ public int size() { } @Override - public StdData get(StdData key) { - return GenericWrapper.createStdData(_map.get(((PlatformData) key).getUnderlyingData()), _valueType); + public V get(K key) { + return (V) GenericWrapper.createStdData(_map.get(GenericWrapper.getPlatformData(key)), _valueType); } @Override - public void put(StdData key, StdData value) { - _map.put(((PlatformData) key).getUnderlyingData(), ((PlatformData) value).getUnderlyingData()); + public void put(K key, V value) { + _map.put(GenericWrapper.getPlatformData(key), GenericWrapper.getPlatformData(value)); } @Override - public Set keySet() { - return new AbstractSet() { + public Set keySet() { + return new AbstractSet() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Iterator keySet = _map.keySet().iterator(); @Override @@ -75,8 +74,8 @@ public boolean hasNext() { } @Override - public StdData next() { - return GenericWrapper.createStdData(keySet.next(), _keyType); + public K next() { + return (K) GenericWrapper.createStdData(keySet.next(), _keyType); } }; } @@ -89,12 +88,12 @@ public int size() { } @Override - public Collection values() { - return _map.values().stream().map(v -> GenericWrapper.createStdData(v, _valueType)).collect(Collectors.toList()); + public Collection values() { + return _map.values().stream().map(v -> (V) GenericWrapper.createStdData(v, _valueType)).collect(Collectors.toList()); } @Override - public boolean containsKey(StdData key) { - return 
_map.containsKey(((PlatformData) key).getUnderlyingData()); + public boolean containsKey(K key) { + return _map.containsKey(GenericWrapper.getPlatformData(key)); } } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericString.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericString.java deleted file mode 100644 index 4bb1babb..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericString.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdString; - - -public class GenericString implements StdString, PlatformData { - private String _string; - - public GenericString(String string) { - _string = string; - } - - @Override - public String get() { - return _string; - } - - @Override - public Object getUnderlyingData() { - return _string; - } - - @Override - public void setUnderlyingData(Object value) { - _string = (String) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericStruct.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericStruct.java index e92b6043..333ddb32 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericStruct.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericStruct.java @@ -6,8 +6,7 @@ package 
com.linkedin.transport.test.generic.data; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.test.generic.GenericWrapper; import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.types.StructTestType; @@ -19,7 +18,7 @@ import java.util.stream.IntStream; -public class GenericStruct implements StdStruct, PlatformData { +public class GenericStruct implements RowData, PlatformData { private Row _struct; private final List _fieldNames; @@ -46,27 +45,27 @@ public void setUnderlyingData(Object value) { } @Override - public StdData getField(int index) { + public Object getField(int index) { return GenericWrapper.createStdData(_struct.getFields().get(index), _fieldTypes.get(index)); } @Override - public StdData getField(String name) { + public Object getField(String name) { return getField(_fieldNames.indexOf(name)); } @Override - public void setField(int index, StdData value) { - _struct.getFields().set(index, ((PlatformData) value).getUnderlyingData()); + public void setField(int index, Object value) { + _struct.getFields().set(index, GenericWrapper.getPlatformData(value)); } @Override - public void setField(String name, StdData value) { + public void setField(String name, Object value) { setField(_fieldNames.indexOf(name), value); } @Override - public List fields() { + public List fields() { return IntStream.range(0, _struct.getFields().size()).mapToObj(this::getField).collect(Collectors.toList()); } } diff --git a/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/udf/MapFromEntries.java b/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/udf/MapFromEntries.java index 4a415fd8..4867fcb4 100644 --- 
a/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/udf/MapFromEntries.java +++ b/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/udf/MapFromEntries.java @@ -7,10 +7,9 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdMapType; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; @@ -21,7 +20,7 @@ * Hive's built-in map() UDF cannot be used to create maps with complex key types. This UDF allows you to do so. * This is used inside {@link com.linkedin.transport.test.hive.HiveTester} to create arbitrary map objects */ -public class MapFromEntries extends StdUDF1 implements TopLevelStdUDF { +public class MapFromEntries extends StdUDF1 implements TopLevelStdUDF { private StdMapType _mapType; @@ -32,10 +31,10 @@ public void init(StdFactory stdFactory) { } @Override - public StdMap eval(StdArray entryArray) { - StdMap result = getStdFactory().createMap(_mapType); - for (StdData element : entryArray) { - StdStruct elementStruct = (StdStruct) element; + public MapData eval(ArrayData entryArray) { + MapData result = getStdFactory().createMap(_mapType); + for (Object element : entryArray) { + RowData elementStruct = (RowData) element; result.put(elementStruct.getField(0), elementStruct.getField(1)); } return result; diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java 
index 0f2d57af..230edb6e 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java @@ -10,8 +10,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Booleans; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.StdUDF1; @@ -42,6 +40,7 @@ import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; + import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; @@ -180,9 +179,9 @@ private List getNullConventio .collect(Collectors.toList()); } - private StdData[] wrapArguments(StdUDF stdUDF, Type[] types, Object[] arguments) { + private Object[] wrapArguments(StdUDF stdUDF, Type[] types, Object[] arguments) { StdFactory stdFactory = stdUDF.getStdFactory(); - StdData[] stdData = new StdData[arguments.length]; + Object[] stdData = new Object[arguments.length]; // TODO: Reuse wrapper objects by creating them once upon initialization and reuse them here // along the same lines of what we do in Hive implementation. // JIRA: https://jira01.corp.linkedin.com:8443/browse/LIHADOOP-34894 @@ -194,12 +193,12 @@ private StdData[] wrapArguments(StdUDF stdUDF, Type[] types, Object[] arguments) protected Object eval(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType, AtomicLong requiredFilesNextRefreshTime, Object... 
arguments) { - StdData[] args = wrapArguments(stdUDF, types, arguments); + Object[] args = wrapArguments(stdUDF, types, arguments); if (requiredFilesNextRefreshTime.get() <= System.currentTimeMillis()) { String[] requiredFiles = getRequiredFiles(stdUDF, args); processRequiredFiles(stdUDF, requiredFiles, requiredFilesNextRefreshTime); } - StdData result; + Object result; switch (args.length) { case 0: result = ((StdUDF0) stdUDF).eval(); @@ -231,16 +230,11 @@ protected Object eval(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType, default: throw new RuntimeException("eval not supported yet for StdUDF" + args.length); } - if (result == null) { - return null; - } else if (isIntegerReturnType) { - return ((Number) ((PlatformData) result).getUnderlyingData()).longValue(); - } else { - return ((PlatformData) result).getUnderlyingData(); - } + + return TrinoWrapper.getPlatformData(result); } - private String[] getRequiredFiles(StdUDF stdUDF, StdData[] args) { + private String[] getRequiredFiles(StdUDF stdUDF, Object[] args) { String[] requiredFiles; switch (args.length) { case 0: diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java index 3b1bde99..46a56ba3 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java @@ -5,31 +5,16 @@ */ package com.linkedin.transport.trino; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdBoolean; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdDouble; -import com.linkedin.transport.api.data.StdFloat; -import 
com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdLong; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; + import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.trino.data.TrinoArray; -import com.linkedin.transport.trino.data.TrinoBoolean; -import com.linkedin.transport.trino.data.TrinoBinary; -import com.linkedin.transport.trino.data.TrinoDouble; -import com.linkedin.transport.trino.data.TrinoFloat; -import com.linkedin.transport.trino.data.TrinoInteger; -import com.linkedin.transport.trino.data.TrinoLong; -import com.linkedin.transport.trino.data.TrinoMap; -import com.linkedin.transport.trino.data.TrinoString; -import com.linkedin.transport.trino.data.TrinoStruct; -import io.airlift.slice.Slices; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; +import com.linkedin.transport.trino.data.TrinoArrayData; +import com.linkedin.transport.trino.data.TrinoMapData; +import com.linkedin.transport.trino.data.TrinoRowData; import io.trino.metadata.FunctionBinding; import io.trino.metadata.FunctionDependencies; import io.trino.metadata.Metadata; @@ -41,7 +26,6 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; -import java.nio.ByteBuffer; import java.util.List; import java.util.stream.Collectors; @@ -68,71 +52,35 @@ public TrinoFactory(FunctionBinding functionBinding, Metadata metadata) { } @Override - public StdInteger createInteger(int value) { - return new TrinoInteger(value); - } - - @Override - public StdLong createLong(long value) { - return new TrinoLong(value); - } - - @Override - public StdBoolean createBoolean(boolean value) { - return new TrinoBoolean(value); - } - - @Override - public StdString createString(String value) { - 
Preconditions.checkNotNull(value, "Cannot create a null StdString"); - return new TrinoString(Slices.utf8Slice(value)); - } - - @Override - public StdFloat createFloat(float value) { - return new TrinoFloat(value); - } - - @Override - public StdDouble createDouble(double value) { - return new TrinoDouble(value); - } - - @Override - public StdBinary createBinary(ByteBuffer value) { - return new TrinoBinary(Slices.wrappedBuffer(value.array())); - } - - @Override - public StdArray createArray(StdType stdType, int expectedSize) { - return new TrinoArray((ArrayType) stdType.underlyingType(), expectedSize, this); + public ArrayData createArray(StdType stdType, int expectedSize) { + return new TrinoArrayData((ArrayType) stdType.underlyingType(), expectedSize, this); } @Override - public StdArray createArray(StdType stdType) { + public ArrayData createArray(StdType stdType) { return createArray(stdType, 0); } @Override - public StdMap createMap(StdType stdType) { - return new TrinoMap((MapType) stdType.underlyingType(), this); + public MapData createMap(StdType stdType) { + return new TrinoMapData((MapType) stdType.underlyingType(), this); } @Override - public TrinoStruct createStruct(List fieldNames, List fieldTypes) { - return new TrinoStruct(fieldNames, + public TrinoRowData createStruct(List fieldNames, List fieldTypes) { + return new TrinoRowData(fieldNames, fieldTypes.stream().map(stdType -> (Type) stdType.underlyingType()).collect(Collectors.toList()), this); } @Override - public TrinoStruct createStruct(List fieldTypes) { - return new TrinoStruct( + public TrinoRowData createStruct(List fieldTypes) { + return new TrinoRowData( fieldTypes.stream().map(stdType -> (Type) stdType.underlyingType()).collect(Collectors.toList()), this); } @Override - public StdStruct createStruct(StdType stdType) { - return new TrinoStruct((RowType) stdType.underlyingType(), this); + public RowData createStruct(StdType stdType) { + return new TrinoRowData((RowType) 
stdType.underlyingType(), this); } @Override diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java index 651daea7..2bde058c 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java @@ -6,18 +6,12 @@ package com.linkedin.transport.trino; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.trino.data.TrinoArray; -import com.linkedin.transport.trino.data.TrinoBoolean; -import com.linkedin.transport.trino.data.TrinoBinary; -import com.linkedin.transport.trino.data.TrinoDouble; -import com.linkedin.transport.trino.data.TrinoFloat; -import com.linkedin.transport.trino.data.TrinoInteger; -import com.linkedin.transport.trino.data.TrinoLong; -import com.linkedin.transport.trino.data.TrinoMap; -import com.linkedin.transport.trino.data.TrinoString; -import com.linkedin.transport.trino.data.TrinoStruct; +import com.linkedin.transport.api.data.PlatformData; +import com.linkedin.transport.trino.data.TrinoData; +import com.linkedin.transport.trino.data.TrinoArrayData; +import com.linkedin.transport.trino.data.TrinoRowData; +import com.linkedin.transport.trino.data.TrinoMapData; import com.linkedin.transport.trino.types.TrinoArrayType; import com.linkedin.transport.trino.types.TrinoBooleanType; import com.linkedin.transport.trino.types.TrinoBinaryType; @@ -27,11 +21,13 @@ import com.linkedin.transport.trino.types.TrinoLongType; import com.linkedin.transport.trino.types.TrinoMapType; import com.linkedin.transport.trino.types.TrinoStringType; -import com.linkedin.transport.trino.types.TrinoStructType; +import com.linkedin.transport.trino.types.TrinoRowType; import 
com.linkedin.transport.trino.types.TrinoUnknownType; import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.ArrayType; import io.trino.spi.type.BigintType; import io.trino.spi.type.BooleanType; @@ -45,32 +41,36 @@ import io.trino.spi.type.VarcharType; import io.trino.type.UnknownType; +import static io.trino.spi.type.BigintType.*; +import static io.trino.spi.type.BooleanType.*; +import static io.trino.spi.type.DoubleType.*; +import static io.trino.spi.type.IntegerType.*; +import static io.trino.spi.type.VarbinaryType.*; +import static io.trino.spi.type.VarcharType.*; import static io.trino.spi.StandardErrorCode.*; import static java.lang.Float.*; import static java.lang.Math.*; import static java.lang.String.*; - +import java.nio.ByteBuffer; public final class TrinoWrapper { private TrinoWrapper() { } - public static StdData createStdData(Object trinoData, Type trinoType, StdFactory stdFactory) { + public static Object createStdData(Object trinoData, Type trinoType, StdFactory stdFactory) { if (trinoData == null) { return null; } if (trinoType instanceof IntegerType) { // Trino represents SQL Integers (i.e., corresponding to IntegerType above) as long or Long - // Therefore, to pass it to the TrinoInteger class, we first cast it to Long, then extract - // the int value. - return new TrinoInteger(((Long) trinoData).intValue()); - } else if (trinoType instanceof BigintType) { - return new TrinoLong((long) trinoData); - } else if (trinoType.getJavaType() == boolean.class) { - return new TrinoBoolean((boolean) trinoData); + // Therefore, we first cast trinoData to Long, then extract the int value. 
+ return ((Long) trinoData).intValue(); + } else if (trinoType instanceof BigintType || trinoType.getJavaType() == boolean.class + || trinoType instanceof DoubleType) { + return trinoData; } else if (trinoType instanceof VarcharType) { - return new TrinoString((Slice) trinoData); + return ((Slice) trinoData).toStringUtf8(); } else if (trinoType instanceof RealType) { // Trino represents SQL Reals (i.e., corresponding to RealType above) as long or Long // Therefore, to pass it to the TrinoFloat class, we first cast it to Long, extract @@ -83,22 +83,69 @@ public static StdData createStdData(Object trinoData, Type trinoType, StdFactory throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Value (%sb) is not a valid single-precision float", Long.toBinaryString(value))); } - return new TrinoFloat(intBitsToFloat(floatValue)); - } else if (trinoType instanceof DoubleType) { - return new TrinoDouble((double) trinoData); + return intBitsToFloat(floatValue); } else if (trinoType instanceof VarbinaryType) { - return new TrinoBinary((Slice) trinoData); + return ((Slice) trinoData).toByteBuffer(); } else if (trinoType instanceof ArrayType) { - return new TrinoArray((Block) trinoData, (ArrayType) trinoType, stdFactory); + return new TrinoArrayData((Block) trinoData, (ArrayType) trinoType, stdFactory); } else if (trinoType instanceof MapType) { - return new TrinoMap((Block) trinoData, trinoType, stdFactory); + return new TrinoMapData((Block) trinoData, trinoType, stdFactory); } else if (trinoType instanceof RowType) { - return new TrinoStruct((Block) trinoData, trinoType, stdFactory); + return new TrinoRowData((Block) trinoData, trinoType, stdFactory); } assert false : "Unrecognized Trino Type: " + trinoType.getClass(); return null; } + public static Object getPlatformData(Object transportData) { + if (transportData == null) { + return null; + } + if (transportData instanceof Integer) { + return ((Number) transportData).longValue(); + } else if (transportData instanceof Long) 
{ + return ((Long) transportData).longValue(); + } else if (transportData instanceof Float) { + return (long) floatToIntBits((Float) transportData); + } else if (transportData instanceof Double) { + return ((Double) transportData).doubleValue(); + } else if (transportData instanceof Boolean) { + return ((Boolean) transportData).booleanValue(); + } else if (transportData instanceof String) { + return Slices.utf8Slice((String) transportData); + } else if (transportData instanceof ByteBuffer) { + return Slices.wrappedBuffer(((ByteBuffer) transportData).array()); + } else { + return ((PlatformData) transportData).getUnderlyingData(); + } + } + + public static void writeToBlock(Object transportData, BlockBuilder blockBuilder) { + if (transportData == null) { + blockBuilder.appendNull(); + } else { + if (transportData instanceof Integer) { + // This looks a bit strange, but the call to writeLong is correct here. INTEGER does not have a writeInt method for + // some reason. It uses BlockBuilder.writeInt internally. 
+ INTEGER.writeLong(blockBuilder, (Integer) transportData); + } else if (transportData instanceof Long) { + BIGINT.writeLong(blockBuilder, (Long) transportData); + } else if (transportData instanceof Float) { + INTEGER.writeLong(blockBuilder, floatToIntBits((Float) transportData)); + } else if (transportData instanceof Double) { + DOUBLE.writeDouble(blockBuilder, (Double) transportData); + } else if (transportData instanceof Boolean) { + BOOLEAN.writeBoolean(blockBuilder, (Boolean) transportData); + } else if (transportData instanceof String) { + VARCHAR.writeSlice(blockBuilder, Slices.utf8Slice((String) transportData)); + } else if (transportData instanceof ByteBuffer) { + VARBINARY.writeSlice(blockBuilder, Slices.wrappedBuffer((ByteBuffer) transportData)); + } else { + ((TrinoData) transportData).writeToBlock(blockBuilder); + } + } + } + public static StdType createStdType(Object trinoType) { if (trinoType instanceof IntegerType) { return new TrinoIntegerType((IntegerType) trinoType); @@ -119,7 +166,7 @@ public static StdType createStdType(Object trinoType) { } else if (trinoType instanceof MapType) { return new TrinoMapType((MapType) trinoType); } else if (trinoType instanceof RowType) { - return new TrinoStructType(((RowType) trinoType)); + return new TrinoRowType(((RowType) trinoType)); } else if (trinoType instanceof UnknownType) { return new TrinoUnknownType(((UnknownType) trinoType)); } @@ -137,4 +184,4 @@ public static int checkedIndexToBlockPosition(Block block, long index) { } return -1; // -1 indicates that the element is out of range and the calling function should return null } -} +} \ No newline at end of file diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArrayData.java similarity index 76% rename from transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java rename to 
transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArrayData.java index 4d0dfa5d..3fe21ffe 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArrayData.java @@ -6,20 +6,19 @@ package com.linkedin.transport.trino.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.trino.TrinoWrapper; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; +import com.linkedin.transport.api.data.ArrayData; import java.util.Iterator; import static io.trino.spi.type.TypeUtils.*; -public class TrinoArray extends TrinoData implements StdArray { +public class TrinoArrayData extends TrinoData implements ArrayData { private final StdFactory _stdFactory; private final ArrayType _arrayType; @@ -28,14 +27,14 @@ public class TrinoArray extends TrinoData implements StdArray { private Block _block; private BlockBuilder _mutable; - public TrinoArray(Block block, ArrayType arrayType, StdFactory stdFactory) { + public TrinoArrayData(Block block, ArrayType arrayType, StdFactory stdFactory) { _block = block; _arrayType = arrayType; _elementType = arrayType.getElementType(); _stdFactory = stdFactory; } - public TrinoArray(ArrayType arrayType, int expectedEntries, StdFactory stdFactory) { + public TrinoArrayData(ArrayType arrayType, int expectedEntries, StdFactory stdFactory) { _block = null; _elementType = arrayType.getElementType(); _mutable = _elementType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), expectedEntries); @@ -49,19 +48,19 @@ public int size() { } @Override - public StdData get(int idx) { + public E get(int idx) { Block sourceBlock = _mutable == null ? 
_block : _mutable; int position = TrinoWrapper.checkedIndexToBlockPosition(sourceBlock, idx); Object element = readNativeValue(_elementType, sourceBlock, position); - return TrinoWrapper.createStdData(element, _elementType, _stdFactory); + return (E) TrinoWrapper.createStdData(element, _elementType, _stdFactory); } @Override - public void add(StdData e) { + public void add(E e) { if (_mutable == null) { _mutable = _elementType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); } - ((TrinoData) e).writeToBlock(_mutable); + TrinoWrapper.writeToBlock(e, _mutable); } @Override @@ -75,10 +74,10 @@ public void setUnderlyingData(Object value) { } @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Block sourceBlock = _mutable == null ? _block : _mutable; - int size = TrinoArray.this.size(); + int size = TrinoArrayData.this.size(); int position = 0; @Override @@ -87,10 +86,10 @@ public boolean hasNext() { } @Override - public StdData next() { + public E next() { Object element = readNativeValue(_elementType, sourceBlock, position); position++; - return TrinoWrapper.createStdData(element, _elementType, _stdFactory); + return (E) TrinoWrapper.createStdData(element, _elementType, _stdFactory); } }; } diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBinary.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBinary.java deleted file mode 100644 index 9fa7914b..00000000 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBinary.java +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.trino.data; - -import com.linkedin.transport.api.data.StdBinary; -import io.airlift.slice.Slice; -import io.trino.spi.block.BlockBuilder; -import java.nio.ByteBuffer; - -import static io.trino.spi.type.VarbinaryType.*; - -public class TrinoBinary extends TrinoData implements StdBinary { - - private Slice _slice; - - public TrinoBinary(Slice slice) { - _slice = slice; - } - - @Override - public ByteBuffer get() { - return _slice.toByteBuffer(); - } - - @Override - public Object getUnderlyingData() { - return _slice; - } - - @Override - public void setUnderlyingData(Object value) { - _slice = (Slice) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - VARBINARY.writeSlice(blockBuilder, _slice); - } -} diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBoolean.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBoolean.java deleted file mode 100644 index 9b6c9e23..00000000 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBoolean.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.trino.data; - -import com.linkedin.transport.api.data.StdBoolean; -import io.trino.spi.block.BlockBuilder; - -import static io.trino.spi.type.BooleanType.*; - - -public class TrinoBoolean extends TrinoData implements StdBoolean { - - boolean _value; - - public TrinoBoolean(boolean value) { - _value = value; - } - - @Override - public boolean get() { - return _value; - } - - @Override - public Object getUnderlyingData() { - return _value; - } - - @Override - public void setUnderlyingData(Object value) { - _value = (boolean) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - BOOLEAN.writeBoolean(blockBuilder, _value); - } -} diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoDouble.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoDouble.java deleted file mode 100644 index 6e3567ec..00000000 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoDouble.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.trino.data; - -import com.linkedin.transport.api.data.StdDouble; -import io.trino.spi.block.BlockBuilder; - -import static io.trino.spi.type.DoubleType.*; - - -public class TrinoDouble extends TrinoData implements StdDouble { - - private double _double; - - public TrinoDouble(double aDouble) { - _double = aDouble; - } - - @Override - public double get() { - return _double; - } - - @Override - public Object getUnderlyingData() { - return _double; - } - - @Override - public void setUnderlyingData(Object value) { - _double = (double) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - DOUBLE.writeDouble(blockBuilder, _double); - } -} \ No newline at end of file diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java deleted file mode 100644 index 16893bcc..00000000 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.trino.data; - -import com.linkedin.transport.api.data.StdFloat; -import io.trino.spi.block.BlockBuilder; - -import static java.lang.Float.*; - - -public class TrinoFloat extends TrinoData implements StdFloat { - - private float _float; - - public TrinoFloat(float aFloat) { - _float = aFloat; - } - - @Override - public float get() { - return _float; - } - - @Override - public Object getUnderlyingData() { - return (long) floatToIntBits(_float); - } - - @Override - public void setUnderlyingData(Object value) { - _float = (float) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - blockBuilder.writeInt(floatToIntBits(_float)); - } -} diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoInteger.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoInteger.java deleted file mode 100644 index bc52ad62..00000000 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoInteger.java +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.trino.data; - -import com.linkedin.transport.api.data.StdInteger; -import io.trino.spi.block.BlockBuilder; - -import static io.trino.spi.type.IntegerType.*; - - -public class TrinoInteger extends TrinoData implements StdInteger { - - int _integer; - - public TrinoInteger(int integer) { - _integer = integer; - } - - @Override - public int get() { - return _integer; - } - - @Override - public Object getUnderlyingData() { - return _integer; - } - - @Override - public void setUnderlyingData(Object value) { - _integer = ((Long) value).intValue(); - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - // It looks a bit strange, but the call to writeLong is correct here. 
INTEGER does not have a writeInt method for - // some reason. It uses BlockBuilder.writeInt internally. - INTEGER.writeLong(blockBuilder, _integer); - } -} diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoLong.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoLong.java deleted file mode 100644 index 5f842938..00000000 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoLong.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.trino.data; - -import com.linkedin.transport.api.data.StdLong; -import io.trino.spi.block.BlockBuilder; - -import static io.trino.spi.type.BigintType.*; - - -public class TrinoLong extends TrinoData implements StdLong { - - long _value; - - public TrinoLong(long value) { - _value = value; - } - - @Override - public long get() { - return _value; - } - - @Override - public Object getUnderlyingData() { - return _value; - } - - @Override - public void setUnderlyingData(Object value) { - _value = (long) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - BIGINT.writeLong(blockBuilder, _value); - } -} diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMapData.java similarity index 74% rename from transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMapData.java index 73c74637..0bd38ad0 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java +++ 
b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMapData.java @@ -8,9 +8,6 @@ import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; import com.linkedin.transport.trino.TrinoFactory; import com.linkedin.transport.trino.TrinoWrapper; import io.trino.spi.TrinoException; @@ -20,6 +17,7 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; +import com.linkedin.transport.api.data.MapData; import java.lang.invoke.MethodHandle; import java.util.AbstractCollection; import java.util.AbstractSet; @@ -34,7 +32,7 @@ import static io.trino.spi.type.TypeUtils.*; -public class TrinoMap extends TrinoData implements StdMap { +public class TrinoMapData extends TrinoData implements MapData { final Type _keyType; final Type _valueType; @@ -43,7 +41,7 @@ public class TrinoMap extends TrinoData implements StdMap { final StdFactory _stdFactory; Block _block; - public TrinoMap(Type mapType, StdFactory stdFactory) { + public TrinoMapData(Type mapType, StdFactory stdFactory) { BlockBuilder mutable = mapType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); mutable.beginBlockEntry(); mutable.closeEntry(); @@ -58,7 +56,7 @@ public TrinoMap(Type mapType, StdFactory stdFactory) { OperatorType.EQUAL, ImmutableList.of(_keyType, _keyType), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)); } - public TrinoMap(Block block, Type mapType, StdFactory stdFactory) { + public TrinoMapData(Block block, Type mapType, StdFactory stdFactory) { this(mapType, stdFactory); _block = block; } @@ -69,13 +67,12 @@ public int size() { } @Override - public StdData get(StdData key) { - Object trinoKey = ((PlatformData) key).getUnderlyingData(); - int i = 
seekKey(trinoKey); + public V get(K key) { + Object prestoKey = TrinoWrapper.getPlatformData(key); + int i = seekKey(prestoKey); if (i != -1) { Object value = readNativeValue(_valueType, _block, i); - StdData stdValue = TrinoWrapper.createStdData(value, _valueType, _stdFactory); - return stdValue; + return (V) TrinoWrapper.createStdData(value, _valueType, _stdFactory); } else { return null; } @@ -84,10 +81,10 @@ public StdData get(StdData key) { // TODO: Do not copy the _mutable BlockBuilder on every update. As long as updates are append-only or for fixed-size // types, we can skip copying. @Override - public void put(StdData key, StdData value) { + public void put(K key, V value) { BlockBuilder mutable = _mapType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); BlockBuilder entryBuilder = mutable.beginBlockEntry(); - Object trinoKey = ((PlatformData) key).getUnderlyingData(); + Object trinoKey = TrinoWrapper.getPlatformData(key); int valuePosition = seekKey(trinoKey); for (int i = 0; i < _block.getPositionCount(); i += 2) { // Write the current key to the map @@ -95,26 +92,26 @@ public void put(StdData key, StdData value) { // Find out if we need to change the corresponding value if (i == valuePosition - 1) { // Use the user-supplied value - ((TrinoData) value).writeToBlock(entryBuilder); + TrinoWrapper.writeToBlock(value, entryBuilder); } else { // Use the existing value in original _block _valueType.appendTo(_block, i + 1, entryBuilder); } } if (valuePosition == -1) { - ((TrinoData) key).writeToBlock(entryBuilder); - ((TrinoData) value).writeToBlock(entryBuilder); + TrinoWrapper.writeToBlock(key, entryBuilder); + TrinoWrapper.writeToBlock(value, entryBuilder); } mutable.closeEntry(); _block = ((MapType) _mapType).getObject(mutable.build(), 0); } - public Set keySet() { - return new AbstractSet() { + public Set keySet() { + return new AbstractSet() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator 
iterator() { + return new Iterator() { int i = -2; @Override @@ -123,27 +120,27 @@ public boolean hasNext() { } @Override - public StdData next() { + public K next() { i += 2; - return TrinoWrapper.createStdData(readNativeValue(_keyType, _block, i), _keyType, _stdFactory); + return (K) TrinoWrapper.createStdData(readNativeValue(_keyType, _block, i), _keyType, _stdFactory); } }; } @Override public int size() { - return TrinoMap.this.size(); + return TrinoMapData.this.size(); } }; } @Override - public Collection values() { - return new AbstractCollection() { + public Collection values() { + return new AbstractCollection() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { int i = -2; @Override @@ -152,22 +149,25 @@ public boolean hasNext() { } @Override - public StdData next() { + public V next() { i += 2; - return TrinoWrapper.createStdData(readNativeValue(_valueType, _block, i + 1), _valueType, _stdFactory); + return + (V) TrinoWrapper.createStdData( + readNativeValue(_valueType, _block, i + 1), _valueType, _stdFactory + ); } }; } @Override public int size() { - return TrinoMap.this.size(); + return TrinoMapData.this.size(); } }; } @Override - public boolean containsKey(StdData key) { + public boolean containsKey(K key) { return get(key) != null; } diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoRowData.java similarity index 83% rename from transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoRowData.java index c94ae335..74d724fa 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoRowData.java @@ -6,8 
+6,6 @@ package com.linkedin.transport.trino.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; import com.linkedin.transport.trino.TrinoWrapper; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; @@ -15,6 +13,7 @@ import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; +import com.linkedin.transport.api.data.RowData; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -24,28 +23,28 @@ import static io.trino.spi.type.TypeUtils.*; -public class TrinoStruct extends TrinoData implements StdStruct { +public class TrinoRowData extends TrinoData implements RowData { final RowType _rowType; final StdFactory _stdFactory; Block _block; - public TrinoStruct(Type rowType, StdFactory stdFactory) { + public TrinoRowData(Type rowType, StdFactory stdFactory) { _rowType = (RowType) rowType; _stdFactory = stdFactory; } - public TrinoStruct(Block block, Type rowType, StdFactory stdFactory) { + public TrinoRowData(Block block, Type rowType, StdFactory stdFactory) { this(rowType, stdFactory); _block = block; } - public TrinoStruct(List fieldTypes, StdFactory stdFactory) { + public TrinoRowData(List fieldTypes, StdFactory stdFactory) { _stdFactory = stdFactory; _rowType = RowType.anonymous(fieldTypes); } - public TrinoStruct(List fieldNames, List fieldTypes, StdFactory stdFactory) { + public TrinoRowData(List fieldNames, List fieldTypes, StdFactory stdFactory) { _stdFactory = stdFactory; List fields = IntStream.range(0, fieldNames.size()) .mapToObj(i -> new RowType.Field(Optional.ofNullable(fieldNames.get(i)), fieldTypes.get(i))) @@ -54,7 +53,7 @@ public TrinoStruct(List fieldNames, List fieldTypes, StdFactory st } @Override - public StdData getField(int index) { + public Object getField(int index) { int position = TrinoWrapper.checkedIndexToBlockPosition(_block, index); if (position 
== -1) { return null; @@ -65,7 +64,7 @@ public StdData getField(int index) { } @Override - public StdData getField(String name) { + public Object getField(String name) { int index = -1; Type elementType = null; int i = 0; @@ -85,7 +84,7 @@ public StdData getField(String name) { } @Override - public void setField(int index, StdData value) { + public void setField(int index, Object value) { // TODO: This is not the right way to get this object. The status should be passed in from the invocation of the // function and propagated to here. See PRESTO-1359 for more details. BlockBuilderStatus blockBuilderStatus = new PageBuilderStatus().createBlockBuilderStatus(); @@ -94,7 +93,7 @@ public void setField(int index, StdData value) { int i = 0; for (RowType.Field field : _rowType.getFields()) { if (i == index) { - ((TrinoData) value).writeToBlock(rowBlockBuilder); + TrinoWrapper.writeToBlock(value, rowBlockBuilder); } else { if (_block == null) { rowBlockBuilder.appendNull(); @@ -109,13 +108,13 @@ public void setField(int index, StdData value) { } @Override - public void setField(String name, StdData value) { + public void setField(String name, Object value) { BlockBuilder mutable = _rowType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); BlockBuilder rowBlockBuilder = mutable.beginBlockEntry(); int i = 0; for (RowType.Field field : _rowType.getFields()) { if (field.getName().isPresent() && name.equals(field.getName().get())) { - ((TrinoData) value).writeToBlock(rowBlockBuilder); + TrinoWrapper.writeToBlock(value, rowBlockBuilder); } else { if (_block == null) { rowBlockBuilder.appendNull(); @@ -130,8 +129,8 @@ public void setField(String name, StdData value) { } @Override - public List fields() { - ArrayList fields = new ArrayList<>(); + public List fields() { + ArrayList fields = new ArrayList<>(); for (int i = 0; i < _block.getPositionCount(); i++) { Type elementType = _rowType.getFields().get(i).getType(); Object element = 
readNativeValue(elementType, _block, i); diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoString.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoString.java deleted file mode 100644 index 5fc9e7f7..00000000 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoString.java +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.trino.data; - -import com.linkedin.transport.api.data.StdString; -import io.airlift.slice.Slice; -import io.trino.spi.block.BlockBuilder; - -import static io.trino.spi.type.VarcharType.*; - - -public class TrinoString extends TrinoData implements StdString { - - Slice _slice; - - public TrinoString(Slice slice) { - _slice = slice; - } - - @Override - public String get() { - return _slice.toStringUtf8(); - } - - @Override - public Object getUnderlyingData() { - return _slice; - } - - @Override - public void setUnderlyingData(Object value) { - _slice = (Slice) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - VARCHAR.writeSlice(blockBuilder, _slice); - } -} diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoStructType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoRowType.java similarity index 75% rename from transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoStructType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoRowType.java index ae44e08a..e4894727 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoStructType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoRowType.java @@ -5,19 +5,18 
@@ */ package com.linkedin.transport.trino.types; -import com.linkedin.transport.api.types.StdStructType; +import com.linkedin.transport.api.types.RowType; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.trino.TrinoWrapper; -import io.trino.spi.type.RowType; import java.util.List; import java.util.stream.Collectors; -public class TrinoStructType implements StdStructType { +public class TrinoRowType implements RowType { - final RowType rowType; + final io.trino.spi.type.RowType rowType; - public TrinoStructType(RowType rowType) { + public TrinoRowType(io.trino.spi.type.RowType rowType) { this.rowType = rowType; }