Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package com.linkedin.transport.test.trino;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.linkedin.transport.test.spi.Row;
import com.linkedin.transport.test.spi.TestCase;
Expand All @@ -23,7 +22,6 @@
import io.trino.spi.connector.ConnectorFactory;
import io.trino.spi.connector.ConnectorMetadata;
import io.trino.spi.function.BoundSignature;
import io.trino.metadata.FunctionBinding;
import io.trino.spi.function.FunctionId;
import com.linkedin.transport.api.StdFactory;
import com.linkedin.transport.api.udf.StdUDF;
Expand Down Expand Up @@ -103,12 +101,9 @@ public Connector create(String catalogName, Map<String, String> config, Connecto
@Override
public StdFactory getStdFactory() {
if (_stdFactory == null) {
FunctionBinding functionBinding = new FunctionBinding(
new FunctionId("test"),
_stdFactory = new TrinoFactory(
new BoundSignature("test", UNKNOWN, ImmutableList.of()),
ImmutableMap.of(),
ImmutableMap.of());
_stdFactory = new TrinoFactory(functionBinding, new TrinoTestFunctionDependencies(InternalTypeManager.TESTING_TYPE_MANAGER, _runner));
new TrinoTestFunctionDependencies(InternalTypeManager.TESTING_TYPE_MANAGER, _runner));
}
return _stdFactory;
}
Expand Down
4 changes: 3 additions & 1 deletion transportable-udfs-trino-plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ dependencies {
implementation (group:'io.airlift', name: 'log', version: '221')
implementation (group:'com.google.guava', name: 'guava', version: '24.1-jre')
implementation (group:'io.trino', name: 'trino-plugin-toolkit', version: project.ext.'trino-version')
runtimeOnly (group:'io.trino', name: 'trino-main', version: project.ext.'trino-version')
runtimeOnly (group:'io.trino', name: 'trino-main', version: project.ext.'trino-version') {
exclude 'group': 'io.trino', 'module': 'trino-spi'
}
compileOnly(group:'io.trino', name: 'trino-spi', version: project.ext.'trino-version')
testImplementation (group:'io.trino', name: 'trino-main', version: project.ext.'trino-version')
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
import com.google.common.collect.ImmutableSet;
import com.linkedin.transport.typesystem.TypeSignature;
import com.linkedin.transport.typesystem.TypeSignatureElement;
import io.trino.spi.TrinoException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;

import static com.linkedin.transport.typesystem.ConcreteTypeSignatureElement.*;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;


/**
Expand Down Expand Up @@ -49,6 +53,14 @@ static String quoteReservedKeywords(String signature) {
return toTrinoTypeSignatureString(TypeSignature.parse(signature));
}

public static MethodHandle methodHandle(Class<?> clazz, String name, Class<?>... parameterTypes) {
try {
return MethodHandles.lookup().unreflect(clazz.getMethod(name, parameterTypes));
} catch (IllegalAccessException | NoSuchMethodException e) {
throw new TrinoException(GENERIC_INTERNAL_ERROR, e);
}
}

private static String toTrinoTypeSignatureString(TypeSignature typeSignature) {
final TypeSignatureElement typeSignatureBase = typeSignature.getBase();
if (BOOLEAN.equals(typeSignatureBase)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,15 @@
import com.linkedin.transport.api.udf.StdUDF8;
import com.linkedin.transport.api.udf.TopLevelStdUDF;
import com.linkedin.transport.typesystem.GenericTypeSignatureElement;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.SignatureBinder;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionDependencyDeclaration;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.ScalarFunctionAdapter;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.function.Signature;
import io.trino.spi.function.TypeVariableConstraint;
import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction;
import io.trino.spi.classloader.ThreadContextClassLoader;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.type.ArrayType;
Expand All @@ -54,13 +52,15 @@
import java.util.stream.IntStream;
import org.apache.commons.lang3.ClassUtils;

import static com.linkedin.transport.trino.StdUDFUtils.methodHandle;
import static com.linkedin.transport.trino.StdUDFUtils.quoteReservedKeywords;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.*;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.spi.function.OperatorType.*;
import static io.trino.spi.function.TypeVariableConstraint.*;
import static io.trino.spi.function.ScalarFunctionAdapter.NullAdaptationPolicy.RETURN_NULL_ON_NULL;
import static io.trino.spi.function.OperatorType.EQUAL;
import static io.trino.spi.function.TypeVariableConstraint.typeVariable;
import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature;
import static io.trino.util.Reflection.*;

// Suppressing argument naming convention for the evalInternal methods
@SuppressWarnings({"checkstyle:regexpsinglelinejava"})
Expand All @@ -70,6 +70,7 @@ public abstract class StdUdfWrapper {
private static final int JITTER_FACTOR = 50; // to calculate jitter from delay

private final FunctionMetadata functionMetadata;
private final ScalarFunctionAdapter functionAdapter = new ScalarFunctionAdapter(RETURN_NULL_ON_NULL);

public StdUdfWrapper(StdUDF stdUDF) {
this.functionMetadata = FunctionMetadata.builder(FunctionKind.SCALAR)
Expand Down Expand Up @@ -133,8 +134,7 @@ public FunctionDependencyDeclaration getFunctionDependencies(BoundSignature boun

public ScalarFunctionImplementation getScalarFunctionImplementation(BoundSignature boundSignature,
FunctionDependencies functionDependencies, InvocationConvention invocationConvention) {
FunctionBinding functionBinding = SignatureBinder.bindFunction(functionMetadata.getFunctionId(), functionMetadata.getSignature(), boundSignature);
StdFactory stdFactory = new TrinoFactory(functionBinding, functionDependencies);
StdFactory stdFactory = new TrinoFactory(boundSignature, functionDependencies);
StdUDF stdUDF = getStdUDF();
stdUDF.init(stdFactory);
// Subtract a small jitter value so that refresh is triggered on first call
Expand All @@ -145,12 +145,25 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(BoundSignatu
- (new Random()).nextInt(initialJitterInt));
boolean[] nullableArguments = stdUDF.getAndCheckNullableArguments();

ScalarFunctionImplementation res = new ChoicesSpecializedSqlScalarFunction(
return internalGetScalarFunctionImplementation(
boundSignature,
NULLABLE_RETURN,
getMethodHandle(stdUDF, boundSignature, nullableArguments, requiredFilesNextRefreshTime),
getNullConventionForArguments(nullableArguments),
getMethodHandle(stdUDF, boundSignature, nullableArguments, requiredFilesNextRefreshTime)).getScalarFunctionImplementation(invocationConvention);
return res;
invocationConvention
);
}

private ScalarFunctionImplementation internalGetScalarFunctionImplementation(BoundSignature boundSignature, MethodHandle methodHandle,
List<InvocationConvention.InvocationArgumentConvention> nullableArguments, InvocationConvention invocationConvention) {
InvocationConvention actualConvention = new InvocationConvention(nullableArguments, NULLABLE_RETURN, false, false);
MethodHandle internalMethodHandle = functionAdapter.adapt(
methodHandle,
boundSignature.getArgumentTypes(),
actualConvention,
invocationConvention
);
return ScalarFunctionImplementation.builder().methodHandle(internalMethodHandle)
.lambdaInterfaces(ImmutableList.of()).build();
}

private MethodHandle getMethodHandle(StdUDF stdUDF, BoundSignature boundSignature, boolean[] nullableArguments,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package com.linkedin.transport.trino;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.linkedin.transport.api.StdFactory;
import com.linkedin.transport.api.data.StdArray;
import com.linkedin.transport.api.data.StdBoolean;
Expand All @@ -30,7 +29,7 @@
import com.linkedin.transport.trino.data.TrinoString;
import com.linkedin.transport.trino.data.TrinoStruct;
import io.airlift.slice.Slices;
import io.trino.metadata.FunctionBinding;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.metadata.OperatorNotFoundException;
import io.trino.spi.function.InvocationConvention;
Expand All @@ -39,24 +38,19 @@
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.stream.Collectors;

import static com.linkedin.transport.trino.StdUDFUtils.quoteReservedKeywords;
import static io.trino.metadata.SignatureBinder.*;
import static io.trino.sql.analyzer.TypeSignatureTranslator.*;


public class TrinoFactory implements StdFactory {

final FunctionBinding functionBinding;
final BoundSignature boundSignature;
final FunctionDependencies functionDependencies;

public TrinoFactory(FunctionBinding functionBinding, FunctionDependencies functionDependencies) {
this.functionBinding = functionBinding;
public TrinoFactory(BoundSignature boundSignature, FunctionDependencies functionDependencies) {
this.boundSignature = boundSignature;
this.functionDependencies = functionDependencies;
}

Expand Down Expand Up @@ -130,8 +124,7 @@ public StdStruct createStruct(StdType stdType) {

@Override
public StdType createStdType(String typeSignatureStr) {
TypeSignature typeSignature = applyBoundVariables(parseTypeSignature(quoteReservedKeywords(typeSignatureStr), ImmutableSet.of()), functionBinding);
return TrinoWrapper.createStdType(functionDependencies.getType(typeSignature));
return TrinoWrapper.createStdType(boundSignature.getReturnType());
}

public MethodHandle getOperatorHandle(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
import io.trino.spi.type.Type;
import java.util.Iterator;

import static io.trino.spi.type.TypeUtils.*;

import static io.trino.spi.type.TypeUtils.readNativeValue;

public class TrinoArray extends TrinoData implements StdArray {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import io.trino.spi.block.BlockBuilder;
import java.nio.ByteBuffer;

import static io.trino.spi.type.VarbinaryType.*;
import static io.trino.spi.type.VarbinaryType.VARBINARY;

public class TrinoBinary extends TrinoData implements StdBinary {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import com.linkedin.transport.api.data.StdBoolean;
import io.trino.spi.block.BlockBuilder;

import static io.trino.spi.type.BooleanType.*;
import static io.trino.spi.type.BooleanType.BOOLEAN;


public class TrinoBoolean extends TrinoData implements StdBoolean {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import com.linkedin.transport.api.data.StdDouble;
import io.trino.spi.block.BlockBuilder;

import static io.trino.spi.type.DoubleType.*;
import static io.trino.spi.type.DoubleType.DOUBLE;


public class TrinoDouble extends TrinoData implements StdDouble {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import com.linkedin.transport.api.data.StdFloat;
import io.trino.spi.block.BlockBuilder;

import static java.lang.Float.*;
import static java.lang.Float.floatToIntBits;


public class TrinoFloat extends TrinoData implements StdFloat {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import com.linkedin.transport.api.data.StdInteger;
import io.trino.spi.block.BlockBuilder;

import static io.trino.spi.type.IntegerType.*;
import static io.trino.spi.type.IntegerType.INTEGER;


public class TrinoInteger extends TrinoData implements StdInteger {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import com.linkedin.transport.api.data.StdLong;
import io.trino.spi.block.BlockBuilder;

import static io.trino.spi.type.BigintType.*;
import static io.trino.spi.type.BigintType.BIGINT;


public class TrinoLong extends TrinoData implements StdLong {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
import java.util.Iterator;
import java.util.Set;

import static io.trino.spi.StandardErrorCode.*;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.function.InvocationConvention.simpleConvention;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.spi.type.TypeUtils.*;
import static io.trino.spi.type.TypeUtils.readNativeValue;


public class TrinoMap extends TrinoData implements StdMap {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import io.airlift.slice.Slice;
import io.trino.spi.block.BlockBuilder;

import static io.trino.spi.type.VarcharType.*;
import static io.trino.spi.type.VarcharType.VARCHAR;


public class TrinoString extends TrinoData implements StdString {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static io.trino.spi.type.TypeUtils.*;
import static io.trino.spi.type.TypeUtils.readNativeValue;


public class TrinoStruct extends TrinoData implements StdStruct {
Expand Down