Register "passthrough" UDFs with correct ordinal return type (#541)
* register spark groot udf with correct return type

* test showcase

* extend + refactor unit test

---------

Co-authored-by: Limian Zhang <[email protected]>
KevinGe00 and rzhang10 authored Oct 17, 2024
1 parent b507761 commit c51d3b5
Showing 6 changed files with 177 additions and 3 deletions.
@@ -675,10 +675,12 @@ public boolean isOptional(int i) {

createAddUserDefinedFunction("com.linkedin.groot.runtime.udf.spark.HasMemberConsentUDF", ReturnTypes.BOOLEAN,
family(SqlTypeFamily.STRING, SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP));
- createAddUserDefinedFunction("com.linkedin.groot.runtime.udf.spark.RedactFieldIfUDF", ARG1,
+ createAddUserDefinedFunction("com.linkedin.groot.runtime.udf.spark.RedactFieldIfUDF",
+     new OrdinalReturnTypeInferenceV2(1),
family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.ANY));
- createAddUserDefinedFunction("com.linkedin.groot.runtime.udf.spark.RedactSecondarySchemaFieldIfUDF", ARG1, family(
-     SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ARRAY, SqlTypeFamily.STRING, SqlTypeFamily.STRING));
+ createAddUserDefinedFunction("com.linkedin.groot.runtime.udf.spark.RedactSecondarySchemaFieldIfUDF",
+     new OrdinalReturnTypeInferenceV2(1), family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ARRAY,
+     SqlTypeFamily.STRING, SqlTypeFamily.STRING));
createAddUserDefinedFunction("com.linkedin.groot.runtime.udf.spark.GetMappedValueUDF", FunctionReturnTypes.STRING,
family(SqlTypeFamily.STRING, SqlTypeFamily.STRING));
createAddUserDefinedFunction("com.linkedin.groot.runtime.udf.spark.ExtractCollectionUDF",
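For readers skimming the diff: the change swaps the ARG1-style return type for OrdinalReturnTypeInferenceV2(1), so these "passthrough" UDFs declare the type of their second operand as their return type. The sketch below only illustrates that idea in plain Calcite terms; the class name is hypothetical and this is not Coral's OrdinalReturnTypeInferenceV2, but the Calcite interfaces it uses (SqlReturnTypeInference, SqlOperatorBinding) are real.

// Illustrative sketch only (hypothetical class, not Coral's implementation):
// a Calcite return type inference that passes through the type of the operand
// at a fixed, zero-based ordinal, including the nullability of nested fields.
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.SqlReturnTypeInference;

public class PassthroughReturnTypeInference implements SqlReturnTypeInference {
  private final int ordinal;

  public PassthroughReturnTypeInference(int ordinal) {
    this.ordinal = ordinal;
  }

  @Override
  public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
    // The declared return type is exactly the type of the chosen operand.
    return opBinding.getOperandType(ordinal);
  }
}

The new test further down checks exactly this property end to end: a view that applies such a UDF to a struct column keeps the per-field nullability of that struct in the derived Avro schema.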
@@ -0,0 +1,61 @@
/**
* Copyright 2021-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.hive.hive2rel;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.*;


@Description(name = "return_second_arg_struct_udf",
value = "_FUNC_(string, struct) - Returns the second argument (struct) as-is")
public class CoralTestUDFReturnSecondArg extends GenericUDF {

private transient ObjectInspector stringOI;
private transient StructObjectInspector structOI;

@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
// Check the number of arguments
if (arguments.length != 2) {
throw new UDFArgumentLengthException(
"return_struct_udf() requires exactly two arguments: a string and a struct.");
}

// Validate the first argument (string)
if (arguments[0].getCategory() != ObjectInspector.Category.PRIMITIVE || ((PrimitiveObjectInspector) arguments[0])
.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
throw new UDFArgumentException("The first argument must be a string.");
}

// Validate the second argument (struct)
if (arguments[1].getCategory() != ObjectInspector.Category.STRUCT) {
throw new UDFArgumentException("The second argument must be a struct.");
}

// Initialize ObjectInspectors
stringOI = arguments[0];
structOI = (StructObjectInspector) arguments[1];

// Return the ObjectInspector for the struct (second argument)
return structOI;
}

@Override
public Object evaluate(DeferredObject[] arguments) throws HiveException {
// Simply return the second argument as-is
Object structObj = arguments[1].get();
return structObj;
}

@Override
public String getDisplayString(String[] children) {
return "return_struct_udf(" + children[0] + ", " + children[1] + ")";
}
}
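To make the passthrough behavior of this test UDF concrete, here is a minimal, hypothetical harness (not part of this commit): it wires up the ObjectInspectors Hive would pass in, then checks that initialize() hands back the struct's inspector itself and evaluate() returns the second argument untouched. The demo class name and the two-field struct layout are invented; the ObjectInspectorFactory, PrimitiveObjectInspectorFactory and DeferredJavaObject calls are standard Hive APIs.

package com.linkedin.coral.hive.hive2rel; // assumed: same package as the UDF above

import java.util.Arrays;

import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public class CoralTestUDFReturnSecondArgDemo {
  public static void main(String[] args) throws Exception {
    CoralTestUDFReturnSecondArg udf = new CoralTestUDFReturnSecondArg();

    // ObjectInspectors Hive would supply: a string plus a two-field struct.
    ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
    ObjectInspector structOI = ObjectInspectorFactory.getStandardStructObjectInspector(
        Arrays.asList("id", "name"),
        Arrays.<ObjectInspector>asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.javaStringObjectInspector));

    // initialize() returns the struct's ObjectInspector, i.e. the UDF is a passthrough.
    ObjectInspector returnOI = udf.initialize(new ObjectInspector[] { stringOI, structOI });
    System.out.println(returnOI == structOI); // true

    // evaluate() returns the second argument as-is.
    Object struct = Arrays.asList(42, "foo");
    Object result = udf.evaluate(new DeferredJavaObject[] {
        new DeferredJavaObject("ignored"), new DeferredJavaObject(struct) });
    System.out.println(result == struct); // true
  }
}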
@@ -32,6 +32,7 @@
import com.linkedin.coral.common.HiveMetastoreClient;
import com.linkedin.coral.common.HiveMscAdapter;
import com.linkedin.coral.common.functions.FunctionReturnTypes;
import com.linkedin.coral.hive.hive2rel.functions.OrdinalReturnTypeInferenceV2;
import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry;

import static org.apache.calcite.sql.type.OperandTypes.*;
@@ -81,6 +82,9 @@ public static void registerUdfs() {
"com.linkedin.coral.hive.hive2rel.CoralTestUDFReturnStruct", FunctionReturnTypes
.rowOf(ImmutableList.of("isEven", "number"), ImmutableList.of(SqlTypeName.BOOLEAN, SqlTypeName.INTEGER)),
family(SqlTypeFamily.INTEGER));
StaticHiveFunctionRegistry.createAddUserDefinedFunction(
"com.linkedin.coral.hive.hive2rel.CoralTestUDFReturnSecondArg", new OrdinalReturnTypeInferenceV2(1),
family(SqlTypeFamily.STRING, SqlTypeFamily.ANY));
}

private static void initializeTables() {
@@ -104,6 +108,7 @@ private static void initializeTables() {
String baseComplexNullableWithDefaults = loadSchema("base-complex-nullable-with-defaults.avsc");
String basePrimitive = loadSchema("base-primitive.avsc");
String baseComplexNestedStructSameName = loadSchema("base-complex-nested-struct-same-name.avsc");
String baseComplexMixedNullabilities = loadSchema("base-complex-mixed-nullabilities.avsc");

executeCreateTableQuery("default", "basecomplex", baseComplexSchema);
executeCreateTableQuery("default", "basecomplexunioncompatible", baseComplexUnionCompatible);
@@ -125,6 +130,7 @@
executeCreateTableWithPartitionQuery("default", "basenestedcomplex", baseNestedComplexSchema);
executeCreateTableWithPartitionQuery("default", "basecomplexnullablewithdefaults", baseComplexNullableWithDefaults);
executeCreateTableWithPartitionQuery("default", "basecomplexnonnullable", baseComplexNonNullable);
executeCreateTableWithPartitionQuery("default", "basecomplexmixednullabilities", baseComplexMixedNullabilities);

String baseComplexSchemaWithDoc = loadSchema("docTestResources/base-complex-with-doc.avsc");
String baseEnumSchemaWithDoc = loadSchema("docTestResources/base-enum-with-doc.avsc");
@@ -170,6 +176,9 @@ private static void initializeUdfs() {

executeCreateFunctionQuery("default", Collections.singletonList("foo_udf_return_struct"), "FuncIsEven",
"com.linkedin.coral.hive.hive2rel.CoralTestUDFReturnStruct");

executeCreateFunctionQuery("default", Collections.singletonList("innerfield_with_udf"), "ReturnInnerStuct",
"com.linkedin.coral.hive.hive2rel.CoralTestUDFReturnSecondArg");
}

private static void executeCreateTableQuery(String dbName, String tableName, String schema) {
@@ -233,6 +233,25 @@ public void testUdfLessThanHundred() {
Assert.assertEquals(actualSchema.toString(true), TestUtils.loadSchema("testUdfLessThanHundred-expected.avsc"));
}

@Test
public void testPreserveNullabilitiesAfterApplyingOrdinalReturnTypeUDF() {
String viewSql = "CREATE VIEW innerfield_with_udf "
+ "tblproperties('functions' = 'ReturnInnerStuct:com.linkedin.coral.hive.hive2rel.CoralTestUDFReturnSecondArg', "
+ " 'dependencies' = 'ivy://com.linkedin:udf:1.0') " + "AS "
+ "SELECT default_innerfield_with_udf_ReturnInnerStuct('foo', innerRecord) AS innerRecord "
+ "FROM basecomplexmixednullabilities";

TestUtils.executeCreateViewQuery("default", "innerfield_with_udf", viewSql);

ViewToAvroSchemaConverter viewToAvroSchemaConverter = ViewToAvroSchemaConverter.create(hiveMetastoreClient);
Schema actualSchema = viewToAvroSchemaConverter.toAvroSchema("default", "innerfield_with_udf");

// Expect all fields to retain their nullability after applying the UDF, CoralTestUDFReturnSecondArg, that simply
// returns the second argument as is
Assert.assertEquals(actualSchema.toString(true),
TestUtils.loadSchema("testPreserveNullabilitiesAfterApplyingOrdinalReturnTypeUDF-expected.avsc"));
}

@Test
public void testUdfGreaterThanHundred() {
String viewSql = "CREATE VIEW foo_dali_udf2 "
@@ -0,0 +1,45 @@
{
"type": "record",
"name": "OuterRecord",
"fields": [
{
"name": "innerRecord",
"type": {
"type": "record",
"name": "InnerRecord",
"fields": [
{
"name": "String_Field_Non_Nullable",
"type": "string"
},
{
"name": "String_Field_Nullable",
"type": [ "string", "null" ]
},
{
"name" : "Int_Field_Non_Nullable",
"type" : "int"
},
{
"name" : "Int_Field_Nullable",
"type" : [ "int", "null" ]
},
{
"name" : "Array_Col_Non_Nullable",
"type" : {
"type" : "array",
"items" : "string"
}
},
{
"name" : "Array_Col_Nullable",
"type" : [ "null", {
"type" : "array",
"items" : [ "null", "string" ]
} ]
}
]
}
}
]
}
@@ -0,0 +1,38 @@
{
"type" : "record",
"name" : "innerfield_with_udf",
"namespace" : "default.innerfield_with_udf",
"fields" : [ {
"name" : "innerRecord",
"type" : {
"type" : "record",
"name" : "InnerRecord",
"namespace" : "default.innerfield_with_udf.innerfield_with_udf",
"fields" : [ {
"name" : "String_Field_Non_Nullable",
"type" : "string"
}, {
"name" : "String_Field_Nullable",
"type" : [ "string", "null" ]
}, {
"name" : "Int_Field_Non_Nullable",
"type" : "int"
}, {
"name" : "Int_Field_Nullable",
"type" : [ "int", "null" ]
}, {
"name" : "Array_Col_Non_Nullable",
"type" : {
"type" : "array",
"items" : "string"
}
}, {
"name" : "Array_Col_Nullable",
"type" : [ "null", {
"type" : "array",
"items" : [ "null", "string" ]
} ]
} ]
}
} ]
}
