diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index b31dfc288..4400b0df0 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -147,7 +147,7 @@ java_library( "//common/types", "//compiler:compiler_builder", "//runtime", - "//runtime:runtime_helper", + "//runtime:proto_message_runtime_equality", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], diff --git a/extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java index 6abac498d..f52c394ed 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java @@ -29,7 +29,7 @@ import dev.cel.runtime.CelRuntime; import dev.cel.runtime.CelRuntimeBuilder; import dev.cel.runtime.CelRuntimeLibrary; -import dev.cel.runtime.RuntimeEquality; +import dev.cel.runtime.ProtoMessageRuntimeEquality; import java.util.Collection; import java.util.Iterator; import java.util.Set; @@ -66,9 +66,6 @@ public final class CelSetsExtensions implements CelCompilerLibrary, CelRuntimeLi + " are unique, so size does not factor into the computation. If either list is empty," + " the result will be false."; - private static final RuntimeEquality RUNTIME_EQUALITY = - new RuntimeEquality(DynamicProto.create(DefaultMessageFactory.INSTANCE)); - /** Denotes the set extension function. */ public enum Function { CONTAINS( @@ -111,7 +108,7 @@ String getFunction() { } private final ImmutableSet functions; - private final CelOptions celOptions; + private final ProtoMessageRuntimeEquality runtimeEquality; CelSetsExtensions(CelOptions celOptions) { this(celOptions, ImmutableSet.copyOf(Function.values())); @@ -119,7 +116,9 @@ String getFunction() { CelSetsExtensions(CelOptions celOptions, Set functions) { this.functions = ImmutableSet.copyOf(functions); - this.celOptions = celOptions; + this.runtimeEquality = + ProtoMessageRuntimeEquality.create( + DynamicProto.create(DefaultMessageFactory.INSTANCE), celOptions); } @Override @@ -208,7 +207,7 @@ private boolean contains(Object o, Collection list) { } private boolean objectsEquals(Object o1, Object o2) { - return RUNTIME_EQUALITY.objectEquals(o1, o2, celOptions); + return runtimeEquality.objectEquals(o1, o2); } private boolean setIntersects(Collection listA, Collection listB) { diff --git a/publish/BUILD.bazel b/publish/BUILD.bazel index 14974d284..4ec1e68ab 100644 --- a/publish/BUILD.bazel +++ b/publish/BUILD.bazel @@ -9,7 +9,7 @@ RUNTIME_TARGETS = [ "//runtime/src/main/java/dev/cel/runtime", "//runtime/src/main/java/dev/cel/runtime:base", "//runtime/src/main/java/dev/cel/runtime:interpreter", - "//runtime/src/main/java/dev/cel/runtime:runtime_helper", + "//runtime/src/main/java/dev/cel/runtime:runtime_helpers", "//runtime/src/main/java/dev/cel/runtime:unknown_attributes", ] diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index a55058241..d9965e8c4 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -25,9 +25,27 @@ java_library( ) java_library( - name = "runtime_helper", + name = "runtime_helpers", visibility = ["//visibility:public"], - exports = ["//runtime/src/main/java/dev/cel/runtime:runtime_helper"], + exports = ["//runtime/src/main/java/dev/cel/runtime:runtime_helpers"], +) + +java_library( + name = "proto_message_runtime_helpers", + visibility = ["//visibility:public"], + exports = ["//runtime/src/main/java/dev/cel/runtime:proto_message_runtime_helpers"], +) + +java_library( + name = "runtime_equality", + visibility = ["//visibility:public"], + exports = ["//runtime/src/main/java/dev/cel/runtime:runtime_equality"], +) + +java_library( + name = "proto_message_runtime_equality", + visibility = ["//visibility:public"], + exports = ["//runtime/src/main/java/dev/cel/runtime:proto_message_runtime_equality"], ) java_library( diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 871e32cb1..554f388d8 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -80,7 +80,7 @@ java_library( ":evaluation_exception_builder", ":evaluation_listener", ":metadata", - ":runtime_helper", + ":runtime_helpers", ":unknown_attributes", "//:auto_value", "//common", @@ -105,31 +105,80 @@ java_library( ) java_library( - name = "runtime_helper", + name = "runtime_equality", srcs = [ "RuntimeEquality.java", - "RuntimeHelpers.java", ], tags = [ ], - # NOTE: do not grow this dependencies arbitrarily deps = [ + ":runtime_helpers", "//common:error_codes", "//common:options", "//common:runtime_exception", - "//common/annotations", "//common/internal:comparison_functions", - "//common/internal:converter", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + ], +) + +java_library( + name = "proto_message_runtime_equality", + srcs = [ + "ProtoMessageRuntimeEquality.java", + ], + tags = [ + ], + deps = [ + ":proto_message_runtime_helpers", + ":runtime_equality", + "//common:options", + "//common/annotations", "//common/internal:dynamic_proto", "//common/internal:proto_equality", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", + ], +) + +java_library( + name = "runtime_helpers", + srcs = [ + "RuntimeHelpers.java", + ], + tags = [ + ], + deps = [ + "//common:error_codes", + "//common:options", + "//common:runtime_exception", + "//common/internal:converter", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_google_re2j_re2j", "@maven//:org_threeten_threeten_extra", ], ) +java_library( + name = "proto_message_runtime_helpers", + srcs = [ + "ProtoMessageRuntimeHelpers.java", + ], + tags = [ + ], + deps = [ + ":runtime_helpers", + "//common:options", + "//common/annotations", + "//common/internal:dynamic_proto", + "@maven//:com_google_protobuf_protobuf_java", + ], +) + # keep sorted RUNTIME_SOURCES = [ "CelFunctionOverload.java", @@ -196,7 +245,10 @@ java_library( deps = [ ":evaluation_exception", ":evaluation_listener", - ":runtime_helper", + ":interpreter", + ":proto_message_runtime_equality", + ":runtime_equality", + ":runtime_helpers", ":runtime_type_provider_legacy", ":unknown_attributes", "//:auto_value", @@ -214,7 +266,6 @@ java_library( "//common/types:cel_types", "//common/values:cel_value_provider", "//common/values:proto_message_value_provider", - "//runtime:interpreter", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -271,6 +322,7 @@ java_library( name = "runtime_type_provider_legacy", srcs = ["RuntimeTypeProviderLegacyImpl.java"], deps = [ + ":interpreter", ":unknown_attributes", "//common:error_codes", "//common:options", @@ -282,7 +334,6 @@ java_library( "//common/values:cel_value", "//common/values:cel_value_provider", "//common/values:proto_message_value", - "//runtime:interpreter", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java index 7ff19fd17..7665a04a1 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java @@ -241,10 +241,12 @@ public CelRuntimeLegacyImpl build() { runtimeTypeFactory, DefaultMessageFactory.create(celDescriptorPool)); DynamicProto dynamicProto = DynamicProto.create(runtimeTypeFactory); + RuntimeEquality runtimeEquality = ProtoMessageRuntimeEquality.create(dynamicProto, options); ImmutableMap.Builder functionBindingsBuilder = ImmutableMap.builder(); - for (CelFunctionBinding standardFunctionBinding : newStandardFunctionBindings(dynamicProto)) { + for (CelFunctionBinding standardFunctionBinding : + newStandardFunctionBindings(runtimeEquality)) { functionBindingsBuilder.put( standardFunctionBinding.getOverloadId(), standardFunctionBinding); } @@ -283,7 +285,7 @@ public CelRuntimeLegacyImpl build() { } private ImmutableSet newStandardFunctionBindings( - DynamicProto dynamicProto) { + RuntimeEquality runtimeEquality) { CelStandardFunctions celStandardFunctions; if (standardEnvironmentEnabled) { celStandardFunctions = @@ -336,7 +338,7 @@ private ImmutableSet newStandardFunctionBindings( return ImmutableSet.of(); } - return celStandardFunctions.newFunctionBindings(dynamicProto, options); + return celStandardFunctions.newFunctionBindings(runtimeEquality, options); } private static CelDescriptorPool newDescriptorPool( diff --git a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java index 9640db57d..cb13cc3ea 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java +++ b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java @@ -31,8 +31,8 @@ import dev.cel.common.CelErrorCode; import dev.cel.common.CelOptions; import dev.cel.common.CelRuntimeException; +import dev.cel.common.annotations.Internal; import dev.cel.common.internal.ComparisonFunctions; -import dev.cel.common.internal.DynamicProto; import dev.cel.common.internal.SafeStringFormatter; import dev.cel.runtime.CelRuntime.CelFunctionBinding; import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Arithmetic; @@ -251,16 +251,14 @@ public enum InternalOperator implements StandardOverload { Object.class, List.class, (Object value, List list) -> - bindingHelper.runtimeEquality.inList( - list, value, bindingHelper.celOptions))), + bindingHelper.runtimeEquality.inList(list, value))), IN_MAP( (bindingHelper) -> CelFunctionBinding.from( "in_map", Object.class, Map.class, - (Object key, Map map) -> - bindingHelper.runtimeEquality.inMap(map, key, bindingHelper.celOptions))); + (Object key, Map map) -> bindingHelper.runtimeEquality.inMap(map, key))); private final FunctionBindingCreator bindingCreator; @@ -282,18 +280,14 @@ public enum Relation implements StandardOverload { "equals", Object.class, Object.class, - (Object x, Object y) -> - bindingHelper.runtimeEquality.objectEquals( - x, y, bindingHelper.celOptions))), + bindingHelper.runtimeEquality::objectEquals)), NOT_EQUALS( (bindingHelper) -> CelFunctionBinding.from( "not_equals", Object.class, Object.class, - (Object x, Object y) -> - !bindingHelper.runtimeEquality.objectEquals( - x, y, bindingHelper.celOptions))); + (Object x, Object y) -> !bindingHelper.runtimeEquality.objectEquals(x, y))); private final FunctionBindingCreator bindingCreator; @@ -591,12 +585,7 @@ public enum Index implements StandardOverload { INDEX_MAP( (bindingHelper) -> CelFunctionBinding.from( - "index_map", - Map.class, - Object.class, - (Map map, Object key) -> - bindingHelper.runtimeEquality.indexMap( - map, key, bindingHelper.celOptions))); + "index_map", Map.class, Object.class, bindingHelper.runtimeEquality::indexMap)); private final FunctionBindingCreator bindingCreator; @@ -1738,18 +1727,14 @@ public enum OptionalValue implements StandardOverload { // special cased inside the interpreter. Map.class, String.class, - (Map map, String key) -> - bindingHelper.runtimeEquality.findInMap( - map, key, bindingHelper.celOptions))), + bindingHelper.runtimeEquality::findInMap)), MAP_OPTINDEX_OPTIONAL_VALUE( (bindingHelper) -> CelFunctionBinding.from( "map_optindex_optional_value", Map.class, Object.class, - (Map map, Object key) -> - bindingHelper.runtimeEquality.findInMap( - map, key, bindingHelper.celOptions))), + bindingHelper.runtimeEquality::findInMap)), OPTIONAL_MAP_OPTINDEX_OPTIONAL_VALUE( (bindingHelper) -> CelFunctionBinding.from( @@ -1757,11 +1742,7 @@ public enum OptionalValue implements StandardOverload { Optional.class, Object.class, (Optional optionalMap, Object key) -> - indexOptionalMap( - optionalMap, - key, - bindingHelper.celOptions, - bindingHelper.runtimeEquality))), + indexOptionalMap(optionalMap, key, bindingHelper.runtimeEquality))), OPTIONAL_MAP_INDEX_VALUE( (bindingHelper) -> CelFunctionBinding.from( @@ -1769,11 +1750,7 @@ public enum OptionalValue implements StandardOverload { Optional.class, Object.class, (Optional optionalMap, Object key) -> - indexOptionalMap( - optionalMap, - key, - bindingHelper.celOptions, - bindingHelper.runtimeEquality))), + indexOptionalMap(optionalMap, key, bindingHelper.runtimeEquality))), OPTIONAL_LIST_INDEX_INT( (bindingHelper) -> CelFunctionBinding.from( @@ -1834,9 +1811,10 @@ ImmutableSet getOverloads() { return standardOverloads; } + @Internal public ImmutableSet newFunctionBindings( - DynamicProto dynamicProto, CelOptions celOptions) { - FunctionBindingHelper helper = new FunctionBindingHelper(celOptions, dynamicProto); + RuntimeEquality runtimeEquality, CelOptions celOptions) { + FunctionBindingHelper helper = new FunctionBindingHelper(celOptions, runtimeEquality); ImmutableSet.Builder builder = ImmutableSet.builder(); for (StandardOverload overload : standardOverloads) { builder.add(overload.newFunctionBinding(helper)); @@ -1977,9 +1955,9 @@ private static final class FunctionBindingHelper { private final CelOptions celOptions; private final RuntimeEquality runtimeEquality; - private FunctionBindingHelper(CelOptions celOptions, DynamicProto dynamicProto) { + private FunctionBindingHelper(CelOptions celOptions, RuntimeEquality runtimeEquality) { this.celOptions = celOptions; - this.runtimeEquality = new RuntimeEquality(dynamicProto); + this.runtimeEquality = runtimeEquality; } } @@ -2054,14 +2032,14 @@ private static ZoneId timeZone(String tz) { } private static Object indexOptionalMap( - Optional optionalMap, Object key, CelOptions options, RuntimeEquality runtimeEquality) { + Optional optionalMap, Object key, RuntimeEquality runtimeEquality) { if (!optionalMap.isPresent()) { return Optional.empty(); } Map map = (Map) optionalMap.get(); - return runtimeEquality.findInMap(map, key, options); + return runtimeEquality.findInMap(map, key); } private static Object indexOptionalList(Optional optionalList, long index) { diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java index 2a4e58e2a..a279b9e2f 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java @@ -206,9 +206,7 @@ private IntermediateResult evalInternal(ExecutionFrame frame, CelExpr expr) frame.getEvaluationListener().callback(expr, result.value()); return result; } catch (CelRuntimeException e) { - throw CelEvaluationExceptionBuilder.newBuilder(e) - .setMetadata(metadata, expr.id()) - .build(); + throw CelEvaluationExceptionBuilder.newBuilder(e).setMetadata(metadata, expr.id()).build(); } catch (RuntimeException e) { throw CelEvaluationExceptionBuilder.newBuilder(e.getMessage()) .setCause(e) @@ -418,9 +416,7 @@ private IntermediateResult evalCall(ExecutionFrame frame, CelExpr expr, CelCall } return IntermediateResult.create(attr, dispatchResult); } catch (CelRuntimeException ce) { - throw CelEvaluationExceptionBuilder.newBuilder(ce) - .setMetadata(metadata, expr.id()) - .build(); + throw CelEvaluationExceptionBuilder.newBuilder(ce).setMetadata(metadata, expr.id()).build(); } catch (RuntimeException e) { throw CelEvaluationExceptionBuilder.newBuilder( "Function '%s' failed with arg(s) '%s'", @@ -455,9 +451,7 @@ private ResolvedOverload findOverloadOrThrow( .setMetadata(metadata, expr.id()) .build()); } catch (CelRuntimeException e) { - throw CelEvaluationExceptionBuilder.newBuilder(e) - .setMetadata(metadata, expr.id()) - .build(); + throw CelEvaluationExceptionBuilder.newBuilder(e).setMetadata(metadata, expr.id()).build(); } } diff --git a/runtime/src/main/java/dev/cel/runtime/ProtoMessageRuntimeEquality.java b/runtime/src/main/java/dev/cel/runtime/ProtoMessageRuntimeEquality.java new file mode 100644 index 000000000..ea4b6874d --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/ProtoMessageRuntimeEquality.java @@ -0,0 +1,70 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime; + +import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.Message; +import dev.cel.common.CelOptions; +import dev.cel.common.annotations.Internal; +import dev.cel.common.internal.DynamicProto; +import dev.cel.common.internal.ProtoEquality; +import java.util.Objects; + +/** + * ProtoMessageRuntimeEquality contains methods for performing CEL related equality checks, + * including full protobuf messages by leveraging descriptors. + * + *

CEL Library Internals. Do Not Use. + */ +@Internal +@Immutable +public final class ProtoMessageRuntimeEquality extends RuntimeEquality { + + private final ProtoEquality protoEquality; + + @Internal + public static ProtoMessageRuntimeEquality create( + DynamicProto dynamicProto, CelOptions celOptions) { + return new ProtoMessageRuntimeEquality(dynamicProto, celOptions); + } + + @Override + public boolean objectEquals(Object x, Object y) { + if (celOptions.disableCelStandardEquality()) { + return Objects.equals(x, y); + } + if (x == y) { + return true; + } + + if (celOptions.enableProtoDifferencerEquality()) { + x = runtimeHelper.adaptValue(x); + y = runtimeHelper.adaptValue(y); + if (x instanceof Message) { + if (!(y instanceof Message)) { + return false; + } + return protoEquality.equals((Message) x, (Message) y); + } + } + + return super.objectEquals(x, y); + } + + private ProtoMessageRuntimeEquality(DynamicProto dynamicProto, CelOptions celOptions) { + super(ProtoMessageRuntimeHelpers.create(dynamicProto, celOptions), celOptions); + this.protoEquality = new ProtoEquality(dynamicProto); + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/ProtoMessageRuntimeHelpers.java b/runtime/src/main/java/dev/cel/runtime/ProtoMessageRuntimeHelpers.java new file mode 100644 index 000000000..b18310b5b --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/ProtoMessageRuntimeHelpers.java @@ -0,0 +1,66 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime; + +import com.google.protobuf.Message; +import com.google.protobuf.MessageLiteOrBuilder; +import com.google.protobuf.MessageOrBuilder; +import dev.cel.common.CelOptions; +import dev.cel.common.annotations.Internal; +import dev.cel.common.internal.DynamicProto; +import dev.cel.common.internal.ProtoAdapter; + +/** + * Helper methods for common CEL related routines that require a full protobuf dependency. + * + *

CEL Library Internals. Do Not Use. + */ +@Internal +public final class ProtoMessageRuntimeHelpers extends RuntimeHelpers { + + private final ProtoAdapter protoAdapter; + + @Internal + public static ProtoMessageRuntimeHelpers create( + DynamicProto dynamicProto, CelOptions celOptions) { + return new ProtoMessageRuntimeHelpers( + new ProtoAdapter(dynamicProto, celOptions.enableUnsignedLongs())); + } + + /** + * Adapts a {@code protobuf.Message} to a plain old Java object. + * + *

Well-known protobuf types (wrappers, JSON types) are unwrapped to Java native object + * representations. + * + *

If the incoming {@code obj} is of type {@code google.protobuf.Any} the object is unpacked + * and the proto within is passed to the {@code adaptProtoToValue} method again to ensure the + * message contained within the Any is properly unwrapped if it is a well-known protobuf type. + */ + @Override + Object adaptProtoToValue(MessageLiteOrBuilder obj) { + if (obj instanceof Message) { + return protoAdapter.adaptProtoToValue((MessageOrBuilder) obj); + } + if (obj instanceof Message.Builder) { + return protoAdapter.adaptProtoToValue(((Message.Builder) obj).build()); + } + return obj; + } + + private ProtoMessageRuntimeHelpers(ProtoAdapter protoAdapter) { + this.protoAdapter = protoAdapter; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java b/runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java index 05bda7caa..3dacc79a3 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java @@ -16,46 +16,40 @@ import com.google.common.primitives.UnsignedLong; import com.google.errorprone.annotations.Immutable; -import com.google.protobuf.Message; -import com.google.protobuf.MessageOrBuilder; +import com.google.protobuf.MessageLiteOrBuilder; import dev.cel.common.CelErrorCode; import dev.cel.common.CelOptions; import dev.cel.common.CelRuntimeException; -import dev.cel.common.annotations.Internal; import dev.cel.common.internal.ComparisonFunctions; -import dev.cel.common.internal.DynamicProto; -import dev.cel.common.internal.ProtoEquality; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; import java.util.Set; -/** CEL Library Internals. Do Not Use. */ -@Internal +/** RuntimeEquality contains methods for performing CEL related equality checks. */ @Immutable -public final class RuntimeEquality { +class RuntimeEquality { + protected final RuntimeHelpers runtimeHelper; + protected final CelOptions celOptions; - private final DynamicProto dynamicProto; - private final ProtoEquality protoEquality; - - public RuntimeEquality(DynamicProto dynamicProto) { - this.dynamicProto = dynamicProto; - this.protoEquality = new ProtoEquality(dynamicProto); + static RuntimeEquality create(RuntimeHelpers runtimeHelper, CelOptions celOptions) { + return new RuntimeEquality(runtimeHelper, celOptions); } // Functions // ========= /** Determine whether the {@code list} contains the given {@code value}. */ - public boolean inList(List list, A value, CelOptions celOptions) { + public boolean inList(List list, A value) { if (list.contains(value)) { return true; } if (value instanceof Number) { for (A elem : list) { - if (objectEquals(elem, value, celOptions)) { + if (objectEquals(elem, value)) { return true; } } @@ -65,8 +59,8 @@ public boolean inList(List list, A value, CelOptions celOptions) { /** Bound-checked indexing of maps. */ @SuppressWarnings("unchecked") - public B indexMap(Map map, A index, CelOptions celOptions) { - Optional value = findInMap(map, index, celOptions); + public B indexMap(Map map, A index) { + Optional value = findInMap(map, index); // Use this method rather than the standard 'orElseThrow' method because of the unchecked cast. if (value.isPresent()) { return (B) value.get(); @@ -76,17 +70,17 @@ public B indexMap(Map map, A index, CelOptions celOptions) { } /** Determine whether the {@code map} contains the given {@code key}. */ - public boolean inMap(Map map, A key, CelOptions celOptions) { - return findInMap(map, key, celOptions).isPresent(); + public boolean inMap(Map map, A key) { + return findInMap(map, key).isPresent(); } - public Optional findInMap(Map map, Object index, CelOptions celOptions) { + public Optional findInMap(Map map, Object index) { if (celOptions.disableCelStandardEquality()) { return Optional.ofNullable(map.get(index)); } - if (index instanceof MessageOrBuilder) { - index = RuntimeHelpers.adaptProtoToValue(dynamicProto, (MessageOrBuilder) index, celOptions); + if (index instanceof MessageLiteOrBuilder) { + index = runtimeHelper.adaptProtoToValue((MessageLiteOrBuilder) index); } Object v = map.get(index); if (v != null) { @@ -138,23 +132,18 @@ public Optional findInMap(Map map, Object index, CelOptions celOpt * *

Heterogeneous equality differs from homogeneous equality in that two objects may be * comparable even if they are not of the same type, where type differences are usually trivially - * false. Heterogeneous runtime equality is under consideration in b/71516544. - * - *

Note, uint values are problematic in that they cannot be properly type-tested for equality - * in comparisons with 64-int signed integer values, see b/159183198. This problem only affects - * Java and is typically inconsequential due to the requirement for type-checking expressions - * before they are evaluated. + * false. */ @SuppressWarnings({"rawtypes", "unchecked"}) - public boolean objectEquals(Object x, Object y, CelOptions celOptions) { + public boolean objectEquals(Object x, Object y) { if (celOptions.disableCelStandardEquality()) { return Objects.equals(x, y); } if (x == y) { return true; } - x = RuntimeHelpers.adaptValue(dynamicProto, x, celOptions); - y = RuntimeHelpers.adaptValue(dynamicProto, y, celOptions); + x = runtimeHelper.adaptValue(x); + y = runtimeHelper.adaptValue(y); if (x instanceof Number) { if (!(y instanceof Number)) { return false; @@ -162,11 +151,13 @@ public boolean objectEquals(Object x, Object y, CelOptions celOptions) { return ComparisonFunctions.numericEquals((Number) x, (Number) y); } if (celOptions.enableProtoDifferencerEquality()) { - if (x instanceof Message) { - if (!(y instanceof Message)) { + if (x instanceof MessageLiteOrBuilder) { + if (!(y instanceof MessageLiteOrBuilder)) { return false; } - return protoEquality.equals((Message) x, (Message) y); + // TODO: Implement when CelLiteDescriptor is available + throw new UnsupportedOperationException( + "Proto Differencer equality is not supported for MessageLite."); } } if (x instanceof Iterable) { @@ -182,7 +173,7 @@ public boolean objectEquals(Object x, Object y, CelOptions celOptions) { return false; } try { - if (!objectEquals(xElem, yElems.next(), celOptions)) { + if (!objectEquals(xElem, yElems.next())) { return false; } } catch (IllegalArgumentException iae) { @@ -207,15 +198,15 @@ public boolean objectEquals(Object x, Object y, CelOptions celOptions) { return false; } IllegalArgumentException e = null; - Set entrySet = xMap.entrySet(); + Set entrySet = xMap.entrySet(); for (Map.Entry xEntry : entrySet) { - Optional yVal = findInMap(yMap, xEntry.getKey(), celOptions); + Optional yVal = findInMap(yMap, xEntry.getKey()); // Use isPresent() rather than isEmpty() to stay backwards compatible with Java 8. if (!yVal.isPresent()) { return false; } try { - if (!objectEquals(xEntry.getValue(), yVal.get(), celOptions)) { + if (!objectEquals(xEntry.getValue(), yVal.get())) { return false; } } catch (IllegalArgumentException iae) { @@ -248,4 +239,9 @@ private static Optional unsignedToLongLossless(UnsignedLong v) { } return Optional.empty(); } + + RuntimeEquality(RuntimeHelpers runtimeHelper, CelOptions celOptions) { + this.runtimeHelper = runtimeHelper; + this.celOptions = celOptions; + } } diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java b/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java index b885bb55d..1c20017d1 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeHelpers.java @@ -17,21 +17,17 @@ import static com.google.common.base.Preconditions.checkArgument; import com.google.common.primitives.Ints; -import com.google.common.primitives.UnsignedInts; import com.google.common.primitives.UnsignedLong; import com.google.common.primitives.UnsignedLongs; +import com.google.errorprone.annotations.Immutable; import com.google.protobuf.Duration; -import com.google.protobuf.Message; -import com.google.protobuf.MessageOrBuilder; +import com.google.protobuf.MessageLiteOrBuilder; import com.google.protobuf.NullValue; import com.google.re2j.Pattern; import dev.cel.common.CelErrorCode; import dev.cel.common.CelOptions; import dev.cel.common.CelRuntimeException; -import dev.cel.common.annotations.Internal; import dev.cel.common.internal.Converter; -import dev.cel.common.internal.DynamicProto; -import dev.cel.common.internal.ProtoAdapter; import java.time.format.DateTimeParseException; import java.util.ArrayList; import java.util.List; @@ -43,18 +39,23 @@ * *

CEL Library Internals. Do Not Use. */ -@Internal -public final class RuntimeHelpers { +@Immutable +class RuntimeHelpers { // Maximum and minimum range supported by protobuf Duration values. private static final java.time.Duration DURATION_MAX = java.time.Duration.ofDays(3652500); private static final java.time.Duration DURATION_MIN = DURATION_MAX.negated(); + static RuntimeHelpers create() { + return new RuntimeHelpers(); + } + // Functions // ========= /** Convert a string to a Duration. */ - public static Duration createDurationFromString(String d) { + @SuppressWarnings("AndroidJdkLibsChecker") // DateTimeParseException added in 26 + static Duration createDurationFromString(String d) { try { java.time.Duration dv = AmountFormats.parseUnitBasedDuration(d); // Ensure that the duration value can be adequately represented within a protobuf.Duration. @@ -67,13 +68,7 @@ public static Duration createDurationFromString(String d) { } } - /** Match a string against a regular expression. */ - public static boolean matches(String string, String regexp) { - return matches( - string, regexp, CelOptions.newBuilder().disableCelStandardEquality(false).build()); - } - - public static boolean matches(String string, String regexp, CelOptions celOptions) { + static boolean matches(String string, String regexp, CelOptions celOptions) { Pattern pattern = Pattern.compile(regexp); int maxProgramSize = celOptions.maxRegexProgramSize(); if (maxProgramSize >= 0 && pattern.programSize() > maxProgramSize) { @@ -92,13 +87,8 @@ public static boolean matches(String string, String regexp, CelOptions celOption return pattern.matcher(string).find(); } - /** Create a compiled pattern for the given regular expression. */ - public static Pattern compilePattern(String regexp) { - return Pattern.compile(regexp); - } - /** Concatenates two lists into a new list. */ - public static List concat(List first, List second) { + static List concat(List first, List second) { // TODO: return a view instead of an actual copy. List result = new ArrayList<>(first.size() + second.size()); result.addAll(first); @@ -110,7 +100,7 @@ public static List concat(List first, List second) { // =========== /** Bound-checked indexing of lists. */ - public static A indexList(List list, Number index) { + static A indexList(List list, Number index) { if (index instanceof Double) { return doubleToLongLossless(index.doubleValue()) .map(v -> indexList(list, v)) @@ -134,35 +124,35 @@ public static A indexList(List list, Number index) { // // CEL requires exceptions to be thrown when int arithmetic exceeds the represented range. - public static long int64Add(long x, long y, CelOptions celOptions) { + static long int64Add(long x, long y, CelOptions celOptions) { if (celOptions.errorOnIntWrap()) { return Math.addExact(x, y); } return x + y; } - public static long int64Divide(long x, long y, CelOptions celOptions) { + static long int64Divide(long x, long y, CelOptions celOptions) { if (celOptions.errorOnIntWrap() && x == Long.MIN_VALUE && y == -1) { throw new ArithmeticException("most negative number wraps"); } return x / y; } - public static long int64Multiply(long x, long y, CelOptions celOptions) { + static long int64Multiply(long x, long y, CelOptions celOptions) { if (celOptions.errorOnIntWrap()) { return Math.multiplyExact(x, y); } return x * y; } - public static long int64Negate(long x, CelOptions celOptions) { + static long int64Negate(long x, CelOptions celOptions) { if (celOptions.errorOnIntWrap()) { return Math.negateExact(x); } return -x; } - public static long int64Subtract(long x, long y, CelOptions celOptions) { + static long int64Subtract(long x, long y, CelOptions celOptions) { if (celOptions.errorOnIntWrap()) { return Math.subtractExact(x, y); } @@ -180,7 +170,7 @@ public static long int64Subtract(long x, long y, CelOptions celOptions) { // works for signed long values that are greater than or equal to 0. The former reinterprets the // long as unsigned, using the bits as is. - public static long uint64Add(long x, long y, CelOptions celOptions) { + static long uint64Add(long x, long y, CelOptions celOptions) { if (celOptions.errorOnIntWrap()) { if (x < 0 && y < 0) { // Both numbers are in the upper half of the range, so it must overflow. @@ -197,30 +187,30 @@ public static long uint64Add(long x, long y, CelOptions celOptions) { return x + y; } - public static UnsignedLong uint64Add(UnsignedLong x, UnsignedLong y) { + static UnsignedLong uint64Add(UnsignedLong x, UnsignedLong y) { if (x.compareTo(UnsignedLong.MAX_VALUE.minus(y)) > 0) { throw new ArithmeticException("range overflow on unsigned addition"); } return x.plus(y); } - public static int uint64CompareTo(long x, long y, CelOptions celOptions) { + static int uint64CompareTo(long x, long y, CelOptions celOptions) { return celOptions.enableUnsignedComparisonAndArithmeticIsUnsigned() ? UnsignedLongs.compare(x, y) : UnsignedLong.valueOf(x).compareTo(UnsignedLong.valueOf(y)); } - public static int uint64CompareTo(long x, long y) { + static int uint64CompareTo(long x, long y) { // Features is set to empty, as this class is public and the build visibility is public. // Existing callers expect legacy behavior. return uint64CompareTo(x, y, CelOptions.LEGACY); } - public static int uint64CompareTo(UnsignedLong x, UnsignedLong y) { + static int uint64CompareTo(UnsignedLong x, UnsignedLong y) { return x.compareTo(y); } - public static long uint64Divide(long x, long y, CelOptions celOptions) { + static long uint64Divide(long x, long y, CelOptions celOptions) { try { return celOptions.enableUnsignedComparisonAndArithmeticIsUnsigned() ? UnsignedLongs.divide(x, y) @@ -230,13 +220,13 @@ public static long uint64Divide(long x, long y, CelOptions celOptions) { } } - public static long uint64Divide(long x, long y) { + static long uint64Divide(long x, long y) { // Features is set to empty, as this class is public and the build visibility is public. // Existing callers expect legacy behavior. return uint64Divide(x, y, CelOptions.LEGACY); } - public static UnsignedLong uint64Divide(UnsignedLong x, UnsignedLong y) { + static UnsignedLong uint64Divide(UnsignedLong x, UnsignedLong y) { if (y.equals(UnsignedLong.ZERO)) { throw new CelRuntimeException( new ArithmeticException("/ by zero"), CelErrorCode.DIVIDE_BY_ZERO); @@ -244,7 +234,7 @@ public static UnsignedLong uint64Divide(UnsignedLong x, UnsignedLong y) { return x.dividedBy(y); } - public static long uint64Mod(long x, long y, CelOptions celOptions) { + static long uint64Mod(long x, long y, CelOptions celOptions) { try { return celOptions.enableUnsignedComparisonAndArithmeticIsUnsigned() ? UnsignedLongs.remainder(x, y) @@ -254,7 +244,7 @@ public static long uint64Mod(long x, long y, CelOptions celOptions) { } } - public static UnsignedLong uint64Mod(UnsignedLong x, UnsignedLong y) { + static UnsignedLong uint64Mod(UnsignedLong x, UnsignedLong y) { if (y.equals(UnsignedLong.ZERO)) { throw new CelRuntimeException( new ArithmeticException("/ by zero"), CelErrorCode.DIVIDE_BY_ZERO); @@ -262,13 +252,13 @@ public static UnsignedLong uint64Mod(UnsignedLong x, UnsignedLong y) { return x.mod(y); } - public static long uint64Mod(long x, long y) { + static long uint64Mod(long x, long y) { // Features is set to empty, as this class is public and the build visibility is public. // Existing callers expect legacy behavior. return uint64Mod(x, y, CelOptions.LEGACY); } - public static long uint64Multiply(long x, long y, CelOptions celOptions) { + static long uint64Multiply(long x, long y, CelOptions celOptions) { long z = celOptions.enableUnsignedComparisonAndArithmeticIsUnsigned() ? x * y @@ -279,20 +269,20 @@ public static long uint64Multiply(long x, long y, CelOptions celOptions) { return z; } - public static long uint64Multiply(long x, long y) { + static long uint64Multiply(long x, long y) { // Features is set to empty, as this class is public and the build visibility is public. // Existing callers expect legacy behavior. return uint64Multiply(x, y, CelOptions.LEGACY); } - public static UnsignedLong uint64Multiply(UnsignedLong x, UnsignedLong y) { + static UnsignedLong uint64Multiply(UnsignedLong x, UnsignedLong y) { if (!y.equals(UnsignedLong.ZERO) && x.compareTo(UnsignedLong.MAX_VALUE.dividedBy(y)) > 0) { throw new ArithmeticException("multiply out of unsigned integer range"); } return x.times(y); } - public static long uint64Subtract(long x, long y, CelOptions celOptions) { + static long uint64Subtract(long x, long y, CelOptions celOptions) { if (celOptions.errorOnIntWrap()) { // Throw an overflow error if x < y, as unsigned longs. This happens if y has its high // bit set and x does not, or if they have the same high bit and x < y as signed longs. @@ -304,7 +294,7 @@ public static long uint64Subtract(long x, long y, CelOptions celOptions) { return x - y; } - public static UnsignedLong uint64Subtract(UnsignedLong x, UnsignedLong y) { + static UnsignedLong uint64Subtract(UnsignedLong x, UnsignedLong y) { // Throw an overflow error if x < y, as unsigned longs. This happens if y has its high // bit set and x does not, or if they have the same high bit and x < y as signed longs. if (x.compareTo(y) < 0) { @@ -313,9 +303,6 @@ public static UnsignedLong uint64Subtract(UnsignedLong x, UnsignedLong y) { return x.minus(y); } - // Object equality - // =================== - // Proto Type Adaption // =================== @@ -324,36 +311,34 @@ public static UnsignedLong uint64Subtract(UnsignedLong x, UnsignedLong y) { // want to avoid to do this conversion eagerly, so we create views on the underlying data. // The below code is the extensive boilerplate to do so. - public static Converter identity() { + static Converter identity() { return (A value) -> value; } - public static final Converter INT32_TO_INT64 = Integer::longValue; - - public static final Converter UINT32_TO_UINT64 = UnsignedInts::toLong; + static final Converter INT32_TO_INT64 = Integer::longValue; - public static final Converter FLOAT_TO_DOUBLE = Float::doubleValue; + static final Converter FLOAT_TO_DOUBLE = Float::doubleValue; - public static final Converter INT64_TO_INT32 = Ints::checkedCast; + static final Converter INT64_TO_INT32 = Ints::checkedCast; - public static final Converter DOUBLE_TO_FLOAT = Double::floatValue; + static final Converter DOUBLE_TO_FLOAT = Double::floatValue; /** Adapts a plain old Java object into a CEL value. */ - public static Object adaptValue(DynamicProto dynamicProto, Object value, CelOptions celOptions) { + public Object adaptValue(Object value) { if (value == null) { return NullValue.NULL_VALUE; } if (value instanceof Number) { return maybeAdaptPrimitive(value); } - if (value instanceof MessageOrBuilder) { - return adaptProtoToValue(dynamicProto, (MessageOrBuilder) value, celOptions); + if (value instanceof MessageLiteOrBuilder) { + return adaptProtoToValue((MessageLiteOrBuilder) value); } return value; } /** Adapts a {@code Number} value to its appropriate CEL type. */ - public static Object maybeAdaptPrimitive(Object value) { + static Object maybeAdaptPrimitive(Object value) { if (value instanceof Optional) { Optional optionalVal = (Optional) value; if (!optionalVal.isPresent()) { @@ -380,19 +365,11 @@ public static Object maybeAdaptPrimitive(Object value) { * and the proto within is passed to the {@code adaptProtoToValue} method again to ensure the * message contained within the Any is properly unwrapped if it is a well-known protobuf type. */ - public static Object adaptProtoToValue( - DynamicProto dynamicProto, MessageOrBuilder obj, CelOptions celOptions) { - ProtoAdapter protoAdapter = new ProtoAdapter(dynamicProto, celOptions.enableUnsignedLongs()); - if (obj instanceof Message) { - return protoAdapter.adaptProtoToValue(obj); - } - if (obj instanceof Message.Builder) { - return protoAdapter.adaptProtoToValue(((Message.Builder) obj).build()); - } - return obj; + Object adaptProtoToValue(MessageLiteOrBuilder obj) { + throw new UnsupportedOperationException("Not implemented yet"); } - public static Optional doubleToUnsignedChecked(double v) { + static Optional doubleToUnsignedChecked(double v) { // getExponent of NaN or Infinite will return a Double.MAX_EXPONENT + 1 (or 128) if (v < 0.0 || Math.getExponent(v) >= 64) { return Optional.empty(); @@ -406,7 +383,7 @@ public static Optional doubleToUnsignedChecked(double v) { return Optional.of(UnsignedLong.fromLongBits((long) v)); } - public static Optional doubleToLongChecked(double v) { + static Optional doubleToLongChecked(double v) { // getExponent of NaN or Infinite values will return a Double.MAX_EXPONENT + 1 (or 128) int exp = Math.getExponent(v); if (exp >= 63 && v != Math.scalb(-1.0, 63)) { @@ -415,10 +392,10 @@ public static Optional doubleToLongChecked(double v) { return Optional.of((long) v); } - public static Optional doubleToLongLossless(Number v) { + static Optional doubleToLongLossless(Number v) { Optional conv = doubleToLongChecked(v.doubleValue()); return conv.map(l -> l.doubleValue() == v.doubleValue() ? l : null); } - private RuntimeHelpers() {} + RuntimeHelpers() {} } diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 7c1668b29..a3108eaf6 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -44,7 +44,10 @@ java_library( "//runtime:evaluation_exception_builder", "//runtime:evaluation_listener", "//runtime:interpreter", - "//runtime:runtime_helper", + "//runtime:proto_message_runtime_equality", + "//runtime:proto_message_runtime_helpers", + "//runtime:runtime_equality", + "//runtime:runtime_helpers", "//runtime:unknown_attributes", "//runtime:unknown_options", "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", diff --git a/runtime/src/test/java/dev/cel/runtime/ProtoMessageRuntimeEqualityTest.java b/runtime/src/test/java/dev/cel/runtime/ProtoMessageRuntimeEqualityTest.java new file mode 100644 index 000000000..4120cc360 --- /dev/null +++ b/runtime/src/test/java/dev/cel/runtime/ProtoMessageRuntimeEqualityTest.java @@ -0,0 +1,691 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.UnsignedLong; +import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.ListValue; +import com.google.protobuf.NullValue; +import com.google.protobuf.StringValue; +import com.google.protobuf.Struct; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.protobuf.Value; +import com.google.protobuf.util.Durations; +import com.google.protobuf.util.Timestamps; +import com.google.rpc.context.AttributeContext; +import com.google.rpc.context.AttributeContext.Auth; +import com.google.rpc.context.AttributeContext.Peer; +import com.google.rpc.context.AttributeContext.Request; +import dev.cel.common.CelDescriptorUtil; +import dev.cel.common.CelOptions; +import dev.cel.common.CelRuntimeException; +import dev.cel.common.internal.AdaptingTypes; +import dev.cel.common.internal.BidiConverter; +import dev.cel.common.internal.DefaultDescriptorPool; +import dev.cel.common.internal.DefaultMessageFactory; +import dev.cel.common.internal.DynamicProto; +import java.util.Arrays; +import java.util.List; +import org.jspecify.annotations.Nullable; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public final class ProtoMessageRuntimeEqualityTest { + private static final CelOptions EMPTY_OPTIONS = + CelOptions.newBuilder().disableCelStandardEquality(false).build(); + private static final CelOptions PROTO_EQUALITY = + CelOptions.newBuilder() + .disableCelStandardEquality(false) + .enableProtoDifferencerEquality(true) + .build(); + private static final DynamicProto DYNAMIC_PROTO = + DynamicProto.create( + DefaultMessageFactory.create( + DefaultDescriptorPool.create( + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( + AttributeContext.getDescriptor().getFile())))); + + private static final ProtoMessageRuntimeEquality RUNTIME_EQUALITY_LEGACY_OPTIONS = + ProtoMessageRuntimeEquality.create(DYNAMIC_PROTO, CelOptions.LEGACY); + + private static final ProtoMessageRuntimeEquality RUNTIME_EQUALITY_DEFAULT_OPTIONS = + ProtoMessageRuntimeEquality.create(DYNAMIC_PROTO, CelOptions.DEFAULT); + + private static final ProtoMessageRuntimeEquality RUNTIME_EQUALITY_EMPTY_OPTIONS = + ProtoMessageRuntimeEquality.create(DYNAMIC_PROTO, EMPTY_OPTIONS); + + private static final ProtoMessageRuntimeEquality RUNTIME_EQUALITY_PROTO_EQUALITY = + ProtoMessageRuntimeEquality.create(DYNAMIC_PROTO, PROTO_EQUALITY); + + @Test + public void inMap() throws Exception { + ImmutableMap map = ImmutableMap.of("key", "value", "key2", "value2"); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(map, "key2")).isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(map, "key3")).isFalse(); + + ImmutableMap mixedKeyMap = + ImmutableMap.of( + "key", "value", 2L, "value2", UnsignedLong.valueOf(42), "answer to everything"); + // Integer tests. + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, 2)).isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, 3)).isFalse(); + + // Long tests. + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, -1L)).isFalse(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, 3L)).isFalse(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, 2L)).isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, 42L)).isTrue(); + + // Floating point tests + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, -1.0d)).isFalse(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, 2.1d)).isFalse(); + assertThat( + RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, UnsignedLong.MAX_VALUE.doubleValue())) + .isFalse(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, 2.0d)).isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, Double.NaN)).isFalse(); + + // Unsigned long tests. + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, UnsignedLong.valueOf(1L))) + .isFalse(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, UnsignedLong.valueOf(2L))) + .isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, UnsignedLong.MAX_VALUE)).isFalse(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inMap(mixedKeyMap, UInt64Value.of(2L))).isTrue(); + + // Validate the legacy behavior as well. + assertThat(RUNTIME_EQUALITY_LEGACY_OPTIONS.inMap(mixedKeyMap, 2)).isFalse(); + assertThat(RUNTIME_EQUALITY_LEGACY_OPTIONS.inMap(mixedKeyMap, 2L)).isTrue(); + assertThat(RUNTIME_EQUALITY_LEGACY_OPTIONS.inMap(mixedKeyMap, Int64Value.of(2L))).isFalse(); + assertThat(RUNTIME_EQUALITY_LEGACY_OPTIONS.inMap(mixedKeyMap, UInt64Value.of(2L))).isFalse(); + } + + @Test + public void inList() throws Exception { + ImmutableList list = ImmutableList.of("value", "value2"); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(list, "value")).isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(list, "value3")).isFalse(); + + ImmutableList mixedValueList = ImmutableList.of(1, "value", 2, "value2"); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(mixedValueList, 2)).isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(mixedValueList, 3)).isFalse(); + + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(mixedValueList, 2L)).isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(mixedValueList, 3L)).isFalse(); + + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(mixedValueList, 2.0)).isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(mixedValueList, Double.NaN)).isFalse(); + + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(mixedValueList, UnsignedLong.valueOf(2L))) + .isTrue(); + assertThat(RUNTIME_EQUALITY_EMPTY_OPTIONS.inList(mixedValueList, UnsignedLong.valueOf(3L))) + .isFalse(); + + // Validate the legacy behavior as well. + assertThat(RUNTIME_EQUALITY_LEGACY_OPTIONS.inList(mixedValueList, 2)).isTrue(); + assertThat(RUNTIME_EQUALITY_LEGACY_OPTIONS.inList(mixedValueList, 2L)).isFalse(); + } + + @Test + public void indexMap() throws Exception { + ImmutableMap mixedKeyMap = + ImmutableMap.of(1L, "value", UnsignedLong.valueOf(2L), "value2"); + assertThat(RUNTIME_EQUALITY_DEFAULT_OPTIONS.indexMap(mixedKeyMap, 1.0)).isEqualTo("value"); + assertThat(RUNTIME_EQUALITY_DEFAULT_OPTIONS.indexMap(mixedKeyMap, 2.0)).isEqualTo("value2"); + Assert.assertThrows( + CelRuntimeException.class, + () -> RUNTIME_EQUALITY_LEGACY_OPTIONS.indexMap(mixedKeyMap, 1.0)); + Assert.assertThrows( + CelRuntimeException.class, + () -> RUNTIME_EQUALITY_DEFAULT_OPTIONS.indexMap(mixedKeyMap, 1.1)); + } + + @AutoValue + abstract static class State { + /** + * Expected comparison outcome when equality is performed with the given options. + * + *

The {@code null} value indicates that the outcome is an error. + */ + public abstract @Nullable Boolean outcome(); + + /** Runtime equality instance to use when performing the equality check. */ + public abstract ProtoMessageRuntimeEquality runtimeEquality(); + + public static State create( + @Nullable Boolean outcome, ProtoMessageRuntimeEquality runtimeEquality) { + return new AutoValue_ProtoMessageRuntimeEqualityTest_State(outcome, runtimeEquality); + } + } + + /** Represents expected result states for an equality test case. */ + @AutoValue + abstract static class Result { + + /** The result {@code State} value associated with different feature flag combinations. */ + public abstract ImmutableSet states(); + + /** + * Creates a Result for a comparison that is undefined (throws an Exception) under both equality + * modes. + */ + public static Result undefined() { + return always(null); + } + + /** Creates a Result for a comparison that is false under both equality modes. */ + public static Result alwaysFalse() { + return always(false); + } + + /** Creates a Result for a comparison that is true under both equality modes. */ + public static Result alwaysTrue() { + return always(true); + } + + public static Result unsigned(Boolean outcome) { + return Result.builder() + .states( + ImmutableList.of( + State.create(outcome, RUNTIME_EQUALITY_EMPTY_OPTIONS), + State.create(outcome, RUNTIME_EQUALITY_PROTO_EQUALITY))) + .build(); + } + + private static Result always(@Nullable Boolean outcome) { + return Result.builder() + .states( + ImmutableList.of( + State.create(outcome, RUNTIME_EQUALITY_EMPTY_OPTIONS), + State.create(outcome, RUNTIME_EQUALITY_PROTO_EQUALITY))) + .build(); + } + + private static Result proto(Boolean equalsOutcome, Boolean diffOutcome) { + return Result.builder() + .states( + ImmutableList.of( + State.create(equalsOutcome, RUNTIME_EQUALITY_EMPTY_OPTIONS), + State.create(diffOutcome, RUNTIME_EQUALITY_PROTO_EQUALITY))) + .build(); + } + + public static Builder builder() { + return new AutoValue_ProtoMessageRuntimeEqualityTest_Result.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + abstract Builder states(ImmutableList states); + + abstract Result build(); + } + } + + @Parameter(0) + public Object lhs; + + @Parameter(1) + public Object rhs; + + @Parameter(2) + public Result result; + + @Parameters + public static List data() { + return Arrays.asList( + new Object[][] { + // Boolean tests. + {true, true, Result.alwaysTrue()}, + {BoolValue.of(true), true, Result.alwaysTrue()}, + {Any.pack(BoolValue.of(true)), true, Result.alwaysTrue()}, + {Value.newBuilder().setBoolValue(true).build(), true, Result.alwaysTrue()}, + {true, false, Result.alwaysFalse()}, + {0, false, Result.alwaysFalse()}, + + // Bytes tests. + {ByteString.copyFromUtf8("h¢"), ByteString.copyFromUtf8("h¢"), Result.alwaysTrue()}, + {ByteString.copyFromUtf8("hello"), ByteString.EMPTY, Result.alwaysFalse()}, + {BytesValue.of(ByteString.EMPTY), ByteString.EMPTY, Result.alwaysTrue()}, + { + BytesValue.of(ByteString.copyFromUtf8("h¢")), + ByteString.copyFromUtf8("h¢"), + Result.alwaysTrue() + }, + {Any.pack(BytesValue.of(ByteString.EMPTY)), ByteString.EMPTY, Result.alwaysTrue()}, + {"h¢", ByteString.copyFromUtf8("h¢"), Result.alwaysFalse()}, + + // Double tests. + {1.0, 1.0, Result.alwaysTrue()}, + {Double.valueOf(1.0), 1.0, Result.alwaysTrue()}, + {DoubleValue.of(42.5), 42.5, Result.alwaysTrue()}, + // Floats are unwrapped to double types. + {FloatValue.of(1.0f), 1.0, Result.alwaysTrue()}, + {Value.newBuilder().setNumberValue(-1.5D).build(), -1.5, Result.alwaysTrue()}, + {1.0, -1.0, Result.alwaysFalse()}, + {1.0, 1.0D, Result.alwaysTrue()}, + {1.0, 1.1D, Result.alwaysFalse()}, + {1.0D, 1.1f, Result.alwaysFalse()}, + {1.0, 1, Result.alwaysTrue()}, + + // Float tests. + {1.0f, 1.0f, Result.alwaysTrue()}, + {Float.valueOf(1.0f), 1.0f, Result.alwaysTrue()}, + {1.0f, -1.0f, Result.alwaysFalse()}, + {1.0f, 1.0, Result.alwaysTrue()}, + + // Integer tests. + {16, 16, Result.alwaysTrue()}, + {17, 16, Result.alwaysFalse()}, + {17, 16.0, Result.alwaysFalse()}, + + // Long tests. + {-15L, -15L, Result.alwaysTrue()}, + // Int32 values are unwrapped to int types. + {Int32Value.of(-15), -15L, Result.alwaysTrue()}, + {Int64Value.of(-15L), -15L, Result.alwaysTrue()}, + {Any.pack(Int32Value.of(-15)), -15L, Result.alwaysTrue()}, + {Any.pack(Int64Value.of(-15L)), -15L, Result.alwaysTrue()}, + {-15L, -16L, Result.alwaysFalse()}, + {-15L, -15, Result.alwaysTrue()}, + {-15L, 15.0, Result.alwaysFalse()}, + + // Null tests. + {null, null, Result.alwaysTrue()}, + {false, null, Result.alwaysFalse()}, + {0.0, null, Result.alwaysFalse()}, + {0, null, Result.alwaysFalse()}, + {null, "null", Result.alwaysFalse()}, + {"null", null, Result.alwaysFalse()}, + {null, NullValue.NULL_VALUE, Result.alwaysTrue()}, + {null, ImmutableList.of(), Result.alwaysFalse()}, + {ImmutableMap.of(), null, Result.alwaysFalse()}, + {ByteString.copyFromUtf8(""), null, Result.alwaysFalse()}, + {null, Timestamps.EPOCH, Result.alwaysFalse()}, + {Durations.ZERO, null, Result.alwaysFalse()}, + {NullValue.NULL_VALUE, NullValue.NULL_VALUE, Result.alwaysTrue()}, + { + Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(), + NullValue.NULL_VALUE, + Result.alwaysTrue() + }, + { + Any.pack(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()), + NullValue.NULL_VALUE, + Result.alwaysTrue() + }, + + // String tests. + {"", "", Result.alwaysTrue()}, + {"str", "str", Result.alwaysTrue()}, + {StringValue.of("str"), "str", Result.alwaysTrue()}, + {Value.newBuilder().setStringValue("str").build(), "str", Result.alwaysTrue()}, + {Any.pack(StringValue.of("str")), "str", Result.alwaysTrue()}, + {Any.pack(Value.newBuilder().setStringValue("str").build()), "str", Result.alwaysTrue()}, + {"", "non-empty", Result.alwaysFalse()}, + + // Uint tests. + {UInt32Value.of(1234), 1234L, Result.alwaysTrue()}, + {UInt64Value.of(1234L), 1234L, Result.alwaysTrue()}, + {UInt64Value.of(1234L), Int64Value.of(1234L), Result.alwaysTrue()}, + {UInt32Value.of(1234), UnsignedLong.valueOf(1234L), Result.alwaysTrue()}, + {UInt64Value.of(1234L), UnsignedLong.valueOf(1234L), Result.alwaysTrue()}, + {Any.pack(UInt64Value.of(1234L)), UnsignedLong.valueOf(1234L), Result.alwaysTrue()}, + {UInt32Value.of(123), UnsignedLong.valueOf(1234L), Result.alwaysFalse()}, + {UInt64Value.of(123L), UnsignedLong.valueOf(1234L), Result.alwaysFalse()}, + {Any.pack(UInt64Value.of(123L)), UnsignedLong.valueOf(1234L), Result.alwaysFalse()}, + + // Cross-type equality tests. + {UInt32Value.of(1234), 1234.0, Result.alwaysTrue()}, + {UInt32Value.of(1234), 1234.0, Result.alwaysTrue()}, + {UInt64Value.of(1234L), 1234L, Result.alwaysTrue()}, + {UInt32Value.of(1234), 1234.1, Result.alwaysFalse()}, + {UInt64Value.of(1234L), 1233L, Result.alwaysFalse()}, + {UnsignedLong.valueOf(1234L), 1234L, Result.alwaysTrue()}, + {UnsignedLong.valueOf(1234L), 1234.1, Result.alwaysFalse()}, + {1234L, 1233.2, Result.alwaysFalse()}, + {-1234L, UnsignedLong.valueOf(1233L), Result.alwaysFalse()}, + + // List tests. + // Note, this list equality behaves equivalently to the following expression: + // 1.0 == 1.0 && "dos" == 2.0 && 3.0 == 4.0 + // The middle predicate is an error; however, the last comparison yields false and so + + // the error is short-circuited away. + {Arrays.asList(1.0, "dos", 3.0), Arrays.asList(1.0, 2.0, 4.0), Result.alwaysFalse()}, + {Arrays.asList("1", 2), ImmutableList.of("1", 2), Result.alwaysTrue()}, + {Arrays.asList("1", 2), ImmutableSet.of("1", 2), Result.alwaysTrue()}, + {Arrays.asList(1.0, 2.0, 3.0), Arrays.asList(1.0, 2.0), Result.alwaysFalse()}, + {Arrays.asList(1.0, 3.0), Arrays.asList(1.0, 2.0), Result.alwaysFalse()}, + { + AdaptingTypes.adaptingList( + ImmutableList.of(1, 2, 3), + BidiConverter.of( + ProtoMessageRuntimeHelpers.INT32_TO_INT64, + ProtoMessageRuntimeHelpers.INT64_TO_INT32)), + Arrays.asList(1L, 2L, 3L), + Result.alwaysTrue() + }, + { + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("hello")) + .addValues(Value.newBuilder().setStringValue("world")) + .build(), + ImmutableList.of("hello", "world"), + Result.alwaysTrue() + }, + { + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("hello")) + .addValues(Value.newBuilder().setListValue(ListValue.getDefaultInstance())) + .build(), + ImmutableList.of("hello", "world"), + Result.alwaysFalse() + }, + { + ListValue.newBuilder() + .addValues(Value.newBuilder().setListValue(ListValue.getDefaultInstance())) + .addValues( + Value.newBuilder() + .setListValue( + ListValue.newBuilder() + .addValues(Value.newBuilder().setBoolValue(true)))) + .build(), + ImmutableList.of(ImmutableList.of(), ImmutableList.of(true)), + Result.alwaysTrue() + }, + { + Value.newBuilder() + .setListValue( + ListValue.newBuilder() + .addValues(Value.newBuilder().setNumberValue(-1.5)) + .addValues(Value.newBuilder().setNumberValue(42.25))) + .build(), + AdaptingTypes.adaptingList( + ImmutableList.of(-1.5f, 42.25f), + BidiConverter.of( + ProtoMessageRuntimeHelpers.FLOAT_TO_DOUBLE, + ProtoMessageRuntimeHelpers.DOUBLE_TO_FLOAT)), + Result.alwaysTrue() + }, + + // Map tests. + {ImmutableMap.of("one", 1), ImmutableMap.of("one", "uno"), Result.alwaysFalse()}, + {ImmutableMap.of("two", 2), ImmutableMap.of("two", 3), Result.alwaysFalse()}, + {ImmutableMap.of("one", 2), ImmutableMap.of("two", 3), Result.alwaysFalse()}, + // Note, this map is the composition of the following two tests above where: + // ("one", 1) == ("one", "uno") -> error + // ("two", 2) == ("two", 3) -> false + // Within CEL error && false -> false, and the key order in the test has specifically + // been chosen to exercise this behavior. + { + ImmutableMap.of("one", 1, "two", 2), + ImmutableMap.of("one", "uno", "two", 3), + Result.alwaysFalse() + }, + {ImmutableMap.of("key", "value"), ImmutableMap.of("key", "value"), Result.alwaysTrue()}, + {ImmutableMap.of(), ImmutableMap.of("key", "value"), Result.alwaysFalse()}, + {ImmutableMap.of("key", "value"), ImmutableMap.of("key", "diff"), Result.alwaysFalse()}, + {ImmutableMap.of("key", 42), ImmutableMap.of("key", 42L), Result.alwaysTrue()}, + {ImmutableMap.of("key", 42.0), ImmutableMap.of("key", 42L), Result.alwaysTrue()}, + { + AdaptingTypes.adaptingMap( + ImmutableMap.of("key1", 42, "key2", 31, "key3", 20), + BidiConverter.identity(), + BidiConverter.of( + ProtoMessageRuntimeHelpers.INT32_TO_INT64, + ProtoMessageRuntimeHelpers.INT64_TO_INT32)), + ImmutableMap.of("key1", 42L, "key2", 31L, "key3", 20L), + Result.alwaysTrue() + }, + { + AdaptingTypes.adaptingMap( + ImmutableMap.of(1, 42.5f, 2, 31f, 3, 20.25f), + BidiConverter.of( + ProtoMessageRuntimeHelpers.INT32_TO_INT64, + ProtoMessageRuntimeHelpers.INT64_TO_INT32), + BidiConverter.of( + ProtoMessageRuntimeHelpers.FLOAT_TO_DOUBLE, + ProtoMessageRuntimeHelpers.DOUBLE_TO_FLOAT)), + ImmutableMap.of(1L, 42.5D, 2L, 31D, 3L, 20.25D), + Result.alwaysTrue() + }, + { + AdaptingTypes.adaptingMap( + ImmutableMap.of("1", 42.5f, "2", 31f, "3", 20.25f), + BidiConverter.identity(), + BidiConverter.of( + ProtoMessageRuntimeHelpers.FLOAT_TO_DOUBLE, + ProtoMessageRuntimeHelpers.DOUBLE_TO_FLOAT)), + Struct.getDefaultInstance(), + Result.alwaysFalse() + }, + { + AdaptingTypes.adaptingMap( + ImmutableMap.of("1", 42.5f, "2", 31f, "3", 20.25f), + BidiConverter.identity(), + BidiConverter.of( + ProtoMessageRuntimeHelpers.FLOAT_TO_DOUBLE, + ProtoMessageRuntimeHelpers.DOUBLE_TO_FLOAT)), + Struct.newBuilder() + .putFields("1", Value.newBuilder().setNumberValue(42.5D).build()) + .putFields("2", Value.newBuilder().setNumberValue(31D).build()) + .putFields("3", Value.newBuilder().setNumberValue(20.25D).build()) + .build(), + Result.alwaysTrue() + }, + { + AdaptingTypes.adaptingMap( + ImmutableMap.of("1", 42.5f, "2", 31f, "3", 20.25f), + BidiConverter.identity(), + BidiConverter.of( + ProtoMessageRuntimeHelpers.FLOAT_TO_DOUBLE, + ProtoMessageRuntimeHelpers.DOUBLE_TO_FLOAT)), + Struct.newBuilder() + .putFields("1", Value.newBuilder().setNumberValue(42.5D).build()) + .putFields("2", Value.newBuilder().setNumberValue(31D).build()) + .putFields("3", Value.newBuilder().setStringValue("oops").build()) + .build(), + Result.alwaysFalse() + }, + + // Protobuf tests. + { + AttributeContext.newBuilder().setRequest(Request.getDefaultInstance()).build(), + AttributeContext.newBuilder().setRequest(Request.newBuilder().setHost("")).build(), + Result.alwaysTrue() + }, + { + AttributeContext.newBuilder() + .setRequest(Request.getDefaultInstance()) + .setOrigin(Peer.getDefaultInstance()) + .build(), + AttributeContext.newBuilder().setRequest(Request.getDefaultInstance()).build(), + Result.alwaysFalse() + }, + // Proto differencer unpacks any values. + { + AttributeContext.newBuilder() + .addExtensions( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/google.rpc.context.AttributeContext") + .setValue(ByteString.copyFromUtf8("\032\000:\000")) + .build()) + .build(), + AttributeContext.newBuilder() + .addExtensions( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/google.rpc.context.AttributeContext") + .setValue(ByteString.copyFromUtf8(":\000\032\000")) + .build()) + .build(), + Result.builder() + .states( + ImmutableList.of( + State.create(false, RUNTIME_EQUALITY_EMPTY_OPTIONS), + State.create(true, RUNTIME_EQUALITY_PROTO_EQUALITY))) + .build() + }, + // If type url is missing, fallback to bytes comparison for payload. + { + AttributeContext.newBuilder() + .addExtensions( + Any.newBuilder().setValue(ByteString.copyFromUtf8("\032\000:\000")).build()) + .build(), + AttributeContext.newBuilder() + .addExtensions( + Any.newBuilder().setValue(ByteString.copyFromUtf8(":\000\032\000")).build()) + .build(), + Result.alwaysFalse() + }, + { + AttributeContext.newBuilder() + .setRequest(Request.getDefaultInstance()) + .setOrigin(Peer.getDefaultInstance()) + .build(), + "test string", + Result.alwaysFalse() + }, + { + AttributeContext.newBuilder() + .setRequest(Request.getDefaultInstance()) + .setOrigin(Peer.getDefaultInstance()) + .build(), + null, + Result.alwaysFalse() + }, + { + AttributeContext.newBuilder() + .addExtensions( + Any.pack( + AttributeContext.newBuilder() + .setRequest(Request.getDefaultInstance()) + .setOrigin(Peer.getDefaultInstance()) + .build())) + .build(), + AttributeContext.newBuilder() + .addExtensions( + Any.pack( + AttributeContext.newBuilder() + .setRequest(Request.getDefaultInstance()) + .build())) + .build(), + Result.alwaysFalse() + }, + { + AttributeContext.getDefaultInstance(), + AttributeContext.newBuilder() + .setRequest(Request.newBuilder().setHost("localhost")) + .build(), + Result.alwaysFalse() + }, + // Differently typed messages aren't comparable. + {AttributeContext.getDefaultInstance(), Auth.getDefaultInstance(), Result.alwaysFalse()}, + // Message.equals() treats NaN values as equal. Message differencer treats NaN values + // as inequal (the same behavior as the C++ implementation). + { + AttributeContext.newBuilder() + .setRequest( + Request.newBuilder() + .setAuth( + Auth.newBuilder() + .setClaims( + Struct.newBuilder() + .putFields( + "custom", + Value.newBuilder() + .setNumberValue(Double.NaN) + .build())))) + .build(), + AttributeContext.newBuilder() + .setRequest( + Request.newBuilder() + .setAuth( + Auth.newBuilder() + .setClaims( + Struct.newBuilder() + .putFields( + "custom", + Value.newBuilder() + .setNumberValue(Double.NaN) + .build())))) + .build(), + Result.proto(/* equalsOutcome= */ true, /* diffOutcome= */ false), + }, + + // Note: this is the motivating use case for converting to heterogeneous equality in + // the future. + { + AttributeContext.newBuilder() + .setRequest( + Request.newBuilder() + .setAuth( + Auth.newBuilder() + .setClaims( + Struct.newBuilder() + .putFields( + "custom", + Value.newBuilder().setNumberValue(123.0).build())))) + .build(), + AttributeContext.newBuilder() + .setRequest( + Request.newBuilder() + .setAuth( + Auth.newBuilder() + .setClaims( + Struct.newBuilder() + .putFields( + "custom", + Value.newBuilder().setBoolValue(true).build())))) + .build(), + Result.alwaysFalse(), + }, + }); + } + + @Test + public void objectEquals() throws Exception { + for (State state : result.states()) { + if (state.outcome() == null) { + Assert.assertThrows( + CelRuntimeException.class, () -> state.runtimeEquality().objectEquals(lhs, rhs)); + Assert.assertThrows( + CelRuntimeException.class, () -> state.runtimeEquality().objectEquals(rhs, lhs)); + return; + } + assertThat(state.runtimeEquality().objectEquals(lhs, rhs)).isEqualTo(state.outcome()); + assertThat(state.runtimeEquality().objectEquals(rhs, lhs)).isEqualTo(state.outcome()); + } + } +} diff --git a/runtime/src/test/java/dev/cel/runtime/RuntimeHelpersTest.java b/runtime/src/test/java/dev/cel/runtime/ProtoMessageRuntimeHelpersTest.java similarity index 53% rename from runtime/src/test/java/dev/cel/runtime/RuntimeHelpersTest.java rename to runtime/src/test/java/dev/cel/runtime/ProtoMessageRuntimeHelpersTest.java index b22c3f14e..985fbcc42 100644 --- a/runtime/src/test/java/dev/cel/runtime/RuntimeHelpersTest.java +++ b/runtime/src/test/java/dev/cel/runtime/ProtoMessageRuntimeHelpersTest.java @@ -48,13 +48,15 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public final class RuntimeHelpersTest { +public final class ProtoMessageRuntimeHelpersTest { private static final DynamicProto DYNAMIC_PROTO = DynamicProto.create(DefaultMessageFactory.INSTANCE); + private static final RuntimeHelpers RUNTIME_HELPER = + ProtoMessageRuntimeHelpers.create(DYNAMIC_PROTO, CelOptions.DEFAULT); @Test public void createDurationFromString() throws Exception { - assertThat(RuntimeHelpers.createDurationFromString("15.11s")) + assertThat(ProtoMessageRuntimeHelpers.createDurationFromString("15.11s")) .isEqualTo(Duration.newBuilder().setSeconds(15).setNanos(110000000).build()); } @@ -62,127 +64,136 @@ public void createDurationFromString() throws Exception { public void createDurationFromString_outOfRange() throws Exception { assertThrows( IllegalArgumentException.class, - () -> RuntimeHelpers.createDurationFromString("-320000000000s")); + () -> ProtoMessageRuntimeHelpers.createDurationFromString("-320000000000s")); } @Test public void int64Add() throws Exception { - assertThat(RuntimeHelpers.int64Add(1, 1, CelOptions.LEGACY)).isEqualTo(2); - assertThat(RuntimeHelpers.int64Add(2, 2, CelOptions.DEFAULT)).isEqualTo(4); - assertThat(RuntimeHelpers.int64Add(1, Long.MAX_VALUE, CelOptions.LEGACY)) + assertThat(ProtoMessageRuntimeHelpers.int64Add(1, 1, CelOptions.LEGACY)).isEqualTo(2); + assertThat(ProtoMessageRuntimeHelpers.int64Add(2, 2, CelOptions.DEFAULT)).isEqualTo(4); + assertThat(ProtoMessageRuntimeHelpers.int64Add(1, Long.MAX_VALUE, CelOptions.LEGACY)) .isEqualTo(Long.MIN_VALUE); - assertThat(RuntimeHelpers.int64Add(-1, Long.MIN_VALUE, CelOptions.LEGACY)) + assertThat(ProtoMessageRuntimeHelpers.int64Add(-1, Long.MIN_VALUE, CelOptions.LEGACY)) .isEqualTo(Long.MAX_VALUE); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.int64Add(1, Long.MAX_VALUE, CelOptions.DEFAULT)); + () -> ProtoMessageRuntimeHelpers.int64Add(1, Long.MAX_VALUE, CelOptions.DEFAULT)); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.int64Add(-1, Long.MIN_VALUE, CelOptions.DEFAULT)); + () -> ProtoMessageRuntimeHelpers.int64Add(-1, Long.MIN_VALUE, CelOptions.DEFAULT)); } @Test public void int64Divide() throws Exception { - assertThat(RuntimeHelpers.int64Divide(-44, 11, CelOptions.LEGACY)).isEqualTo(-4); - assertThat(RuntimeHelpers.int64Divide(-44, 11, CelOptions.DEFAULT)).isEqualTo(-4); - assertThat(RuntimeHelpers.int64Divide(Long.MIN_VALUE, -1, CelOptions.LEGACY)) + assertThat(ProtoMessageRuntimeHelpers.int64Divide(-44, 11, CelOptions.LEGACY)).isEqualTo(-4); + assertThat(ProtoMessageRuntimeHelpers.int64Divide(-44, 11, CelOptions.DEFAULT)).isEqualTo(-4); + assertThat(ProtoMessageRuntimeHelpers.int64Divide(Long.MIN_VALUE, -1, CelOptions.LEGACY)) .isEqualTo(Long.MIN_VALUE); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.int64Divide(Long.MIN_VALUE, -1, CelOptions.DEFAULT)); + () -> ProtoMessageRuntimeHelpers.int64Divide(Long.MIN_VALUE, -1, CelOptions.DEFAULT)); } @Test public void int64Multiply() throws Exception { - assertThat(RuntimeHelpers.int64Multiply(2, 3, CelOptions.LEGACY)).isEqualTo(6); - assertThat(RuntimeHelpers.int64Multiply(2, 3, CelOptions.DEFAULT)).isEqualTo(6); - assertThat(RuntimeHelpers.int64Multiply(Long.MIN_VALUE, -1, CelOptions.LEGACY)) + assertThat(ProtoMessageRuntimeHelpers.int64Multiply(2, 3, CelOptions.LEGACY)).isEqualTo(6); + assertThat(ProtoMessageRuntimeHelpers.int64Multiply(2, 3, CelOptions.DEFAULT)).isEqualTo(6); + assertThat(ProtoMessageRuntimeHelpers.int64Multiply(Long.MIN_VALUE, -1, CelOptions.LEGACY)) .isEqualTo(Long.MIN_VALUE); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.int64Multiply(Long.MIN_VALUE, -1, CelOptions.DEFAULT)); + () -> ProtoMessageRuntimeHelpers.int64Multiply(Long.MIN_VALUE, -1, CelOptions.DEFAULT)); } @Test public void int64Negate() throws Exception { - assertThat(RuntimeHelpers.int64Negate(7, CelOptions.LEGACY)).isEqualTo(-7); - assertThat(RuntimeHelpers.int64Negate(7, CelOptions.DEFAULT)).isEqualTo(-7); - assertThat(RuntimeHelpers.int64Negate(Long.MIN_VALUE, CelOptions.LEGACY)) + assertThat(ProtoMessageRuntimeHelpers.int64Negate(7, CelOptions.LEGACY)).isEqualTo(-7); + assertThat(ProtoMessageRuntimeHelpers.int64Negate(7, CelOptions.DEFAULT)).isEqualTo(-7); + assertThat(ProtoMessageRuntimeHelpers.int64Negate(Long.MIN_VALUE, CelOptions.LEGACY)) .isEqualTo(Long.MIN_VALUE); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.int64Negate(Long.MIN_VALUE, CelOptions.DEFAULT)); + () -> ProtoMessageRuntimeHelpers.int64Negate(Long.MIN_VALUE, CelOptions.DEFAULT)); } @Test public void int64Subtract() throws Exception { - assertThat(RuntimeHelpers.int64Subtract(50, 100, CelOptions.LEGACY)).isEqualTo(-50); - assertThat(RuntimeHelpers.int64Subtract(50, 100, CelOptions.DEFAULT)).isEqualTo(-50); - assertThat(RuntimeHelpers.int64Subtract(Long.MIN_VALUE, 1, CelOptions.LEGACY)) + assertThat(ProtoMessageRuntimeHelpers.int64Subtract(50, 100, CelOptions.LEGACY)).isEqualTo(-50); + assertThat(ProtoMessageRuntimeHelpers.int64Subtract(50, 100, CelOptions.DEFAULT)) + .isEqualTo(-50); + assertThat(ProtoMessageRuntimeHelpers.int64Subtract(Long.MIN_VALUE, 1, CelOptions.LEGACY)) .isEqualTo(Long.MAX_VALUE); - assertThat(RuntimeHelpers.int64Subtract(Long.MAX_VALUE, -1, CelOptions.LEGACY)) + assertThat(ProtoMessageRuntimeHelpers.int64Subtract(Long.MAX_VALUE, -1, CelOptions.LEGACY)) .isEqualTo(Long.MIN_VALUE); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.int64Subtract(Long.MIN_VALUE, 1, CelOptions.DEFAULT)); + () -> ProtoMessageRuntimeHelpers.int64Subtract(Long.MIN_VALUE, 1, CelOptions.DEFAULT)); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.int64Subtract(Long.MAX_VALUE, -1, CelOptions.DEFAULT)); + () -> ProtoMessageRuntimeHelpers.int64Subtract(Long.MAX_VALUE, -1, CelOptions.DEFAULT)); } @Test public void uint64CompareTo_unsignedLongs() { - assertThat(RuntimeHelpers.uint64CompareTo(UnsignedLong.ONE, UnsignedLong.ZERO)).isEqualTo(1); - assertThat(RuntimeHelpers.uint64CompareTo(UnsignedLong.ZERO, UnsignedLong.ONE)).isEqualTo(-1); - assertThat(RuntimeHelpers.uint64CompareTo(UnsignedLong.ONE, UnsignedLong.ONE)).isEqualTo(0); + assertThat(ProtoMessageRuntimeHelpers.uint64CompareTo(UnsignedLong.ONE, UnsignedLong.ZERO)) + .isEqualTo(1); + assertThat(ProtoMessageRuntimeHelpers.uint64CompareTo(UnsignedLong.ZERO, UnsignedLong.ONE)) + .isEqualTo(-1); + assertThat(ProtoMessageRuntimeHelpers.uint64CompareTo(UnsignedLong.ONE, UnsignedLong.ONE)) + .isEqualTo(0); assertThat( - RuntimeHelpers.uint64CompareTo( + ProtoMessageRuntimeHelpers.uint64CompareTo( UnsignedLong.valueOf(Long.MAX_VALUE), UnsignedLong.MAX_VALUE)) .isEqualTo(-1); } @Test public void uint64CompareTo_throwsWhenNegativeOrGreaterThanSignedLongMax() throws Exception { - assertThrows(IllegalArgumentException.class, () -> RuntimeHelpers.uint64CompareTo(-1, 0)); - assertThrows(IllegalArgumentException.class, () -> RuntimeHelpers.uint64CompareTo(0, -1)); + assertThrows( + IllegalArgumentException.class, () -> ProtoMessageRuntimeHelpers.uint64CompareTo(-1, 0)); + assertThrows( + IllegalArgumentException.class, () -> ProtoMessageRuntimeHelpers.uint64CompareTo(0, -1)); } @Test public void uint64CompareTo_unsignedComparisonAndArithmeticIsUnsigned() throws Exception { // In twos complement, -1 is represented by all bits being set. This is equivalent to the // maximum value when unsigned. - assertThat(RuntimeHelpers.uint64CompareTo(-1, 0, CelOptions.DEFAULT)).isGreaterThan(0); - assertThat(RuntimeHelpers.uint64CompareTo(0, -1, CelOptions.DEFAULT)).isLessThan(0); + assertThat(ProtoMessageRuntimeHelpers.uint64CompareTo(-1, 0, CelOptions.DEFAULT)) + .isGreaterThan(0); + assertThat(ProtoMessageRuntimeHelpers.uint64CompareTo(0, -1, CelOptions.DEFAULT)).isLessThan(0); } @Test public void uint64Add_signedLongs() throws Exception { - assertThat(RuntimeHelpers.uint64Add(4, 4, CelOptions.LEGACY)).isEqualTo(8); - assertThat(RuntimeHelpers.uint64Add(4, 4, CelOptions.DEFAULT)).isEqualTo(8); - assertThat(RuntimeHelpers.uint64Add(-1, 1, CelOptions.LEGACY)).isEqualTo(0); + assertThat(ProtoMessageRuntimeHelpers.uint64Add(4, 4, CelOptions.LEGACY)).isEqualTo(8); + assertThat(ProtoMessageRuntimeHelpers.uint64Add(4, 4, CelOptions.DEFAULT)).isEqualTo(8); + assertThat(ProtoMessageRuntimeHelpers.uint64Add(-1, 1, CelOptions.LEGACY)).isEqualTo(0); assertThrows( - ArithmeticException.class, () -> RuntimeHelpers.uint64Add(-1, 1, CelOptions.DEFAULT)); + ArithmeticException.class, + () -> ProtoMessageRuntimeHelpers.uint64Add(-1, 1, CelOptions.DEFAULT)); } @Test public void uint64Add_unsignedLongs() throws Exception { - assertThat(RuntimeHelpers.uint64Add(UnsignedLong.valueOf(4), UnsignedLong.valueOf(4))) + assertThat( + ProtoMessageRuntimeHelpers.uint64Add(UnsignedLong.valueOf(4), UnsignedLong.valueOf(4))) .isEqualTo(UnsignedLong.valueOf(8)); assertThat( - RuntimeHelpers.uint64Add( + ProtoMessageRuntimeHelpers.uint64Add( UnsignedLong.MAX_VALUE.minus(UnsignedLong.ONE), UnsignedLong.ONE)) .isEqualTo(UnsignedLong.MAX_VALUE); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.uint64Add(UnsignedLong.MAX_VALUE, UnsignedLong.ONE)); + () -> ProtoMessageRuntimeHelpers.uint64Add(UnsignedLong.MAX_VALUE, UnsignedLong.ONE)); } @Test public void uint64Multiply_signedLongs() throws Exception { - assertThat(RuntimeHelpers.uint64Multiply(32, 2, CelOptions.LEGACY)).isEqualTo(64); - assertThat(RuntimeHelpers.uint64Multiply(32, 2, CelOptions.DEFAULT)).isEqualTo(64); + assertThat(ProtoMessageRuntimeHelpers.uint64Multiply(32, 2, CelOptions.LEGACY)).isEqualTo(64); + assertThat(ProtoMessageRuntimeHelpers.uint64Multiply(32, 2, CelOptions.DEFAULT)).isEqualTo(64); assertThat( - RuntimeHelpers.uint64Multiply( + ProtoMessageRuntimeHelpers.uint64Multiply( Long.MIN_VALUE, 2, CelOptions.newBuilder() @@ -191,105 +202,119 @@ public void uint64Multiply_signedLongs() throws Exception { .isEqualTo(0); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.uint64Multiply(Long.MIN_VALUE, 2, CelOptions.DEFAULT)); + () -> ProtoMessageRuntimeHelpers.uint64Multiply(Long.MIN_VALUE, 2, CelOptions.DEFAULT)); } @Test public void uint64Multiply_unsignedLongs() throws Exception { - assertThat(RuntimeHelpers.uint64Multiply(UnsignedLong.valueOf(32), UnsignedLong.valueOf(2))) + assertThat( + ProtoMessageRuntimeHelpers.uint64Multiply( + UnsignedLong.valueOf(32), UnsignedLong.valueOf(2))) .isEqualTo(UnsignedLong.valueOf(64)); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.uint64Multiply(UnsignedLong.MAX_VALUE, UnsignedLong.valueOf(2))); + () -> + ProtoMessageRuntimeHelpers.uint64Multiply( + UnsignedLong.MAX_VALUE, UnsignedLong.valueOf(2))); } @Test public void uint64Multiply_throwsWhenNegativeOrGreaterThanSignedLongMax() throws Exception { - assertThrows(IllegalArgumentException.class, () -> RuntimeHelpers.uint64Multiply(-1, 0)); - assertThrows(IllegalArgumentException.class, () -> RuntimeHelpers.uint64Multiply(0, -1)); + assertThrows( + IllegalArgumentException.class, () -> ProtoMessageRuntimeHelpers.uint64Multiply(-1, 0)); + assertThrows( + IllegalArgumentException.class, () -> ProtoMessageRuntimeHelpers.uint64Multiply(0, -1)); } @Test public void uint64Multiply_unsignedComparisonAndArithmeticIsUnsigned() throws Exception { // In twos complement, -1 is represented by all bits being set. This is equivalent to the // maximum value when unsigned. - assertThat(RuntimeHelpers.uint64Multiply(-1, 0, CelOptions.DEFAULT)).isEqualTo(0); - assertThat(RuntimeHelpers.uint64Multiply(0, -1, CelOptions.DEFAULT)).isEqualTo(0); + assertThat(ProtoMessageRuntimeHelpers.uint64Multiply(-1, 0, CelOptions.DEFAULT)).isEqualTo(0); + assertThat(ProtoMessageRuntimeHelpers.uint64Multiply(0, -1, CelOptions.DEFAULT)).isEqualTo(0); } @Test public void uint64Divide_unsignedLongs() { - assertThat(RuntimeHelpers.uint64Divide(UnsignedLong.ZERO, UnsignedLong.ONE)) + assertThat(ProtoMessageRuntimeHelpers.uint64Divide(UnsignedLong.ZERO, UnsignedLong.ONE)) .isEqualTo(UnsignedLong.ZERO); - assertThat(RuntimeHelpers.uint64Divide(UnsignedLong.MAX_VALUE, UnsignedLong.MAX_VALUE)) + assertThat( + ProtoMessageRuntimeHelpers.uint64Divide(UnsignedLong.MAX_VALUE, UnsignedLong.MAX_VALUE)) .isEqualTo(UnsignedLong.ONE); assertThrows( CelRuntimeException.class, - () -> RuntimeHelpers.uint64Divide(UnsignedLong.MAX_VALUE, UnsignedLong.ZERO)); + () -> ProtoMessageRuntimeHelpers.uint64Divide(UnsignedLong.MAX_VALUE, UnsignedLong.ZERO)); } @Test public void uint64Divide_throwsWhenNegativeOrGreaterThanSignedLongMax() throws Exception { - assertThrows(IllegalArgumentException.class, () -> RuntimeHelpers.uint64Divide(0, -1)); - assertThrows(IllegalArgumentException.class, () -> RuntimeHelpers.uint64Divide(-1, -1)); + assertThrows( + IllegalArgumentException.class, () -> ProtoMessageRuntimeHelpers.uint64Divide(0, -1)); + assertThrows( + IllegalArgumentException.class, () -> ProtoMessageRuntimeHelpers.uint64Divide(-1, -1)); } @Test public void uint64Divide_unsignedComparisonAndArithmeticIsUnsigned() throws Exception { // In twos complement, -1 is represented by all bits being set. This is equivalent to the // maximum value when unsigned. - assertThat(RuntimeHelpers.uint64Divide(0, -1, CelOptions.DEFAULT)).isEqualTo(0); - assertThat(RuntimeHelpers.uint64Divide(-1, -1, CelOptions.DEFAULT)).isEqualTo(1); + assertThat(ProtoMessageRuntimeHelpers.uint64Divide(0, -1, CelOptions.DEFAULT)).isEqualTo(0); + assertThat(ProtoMessageRuntimeHelpers.uint64Divide(-1, -1, CelOptions.DEFAULT)).isEqualTo(1); } @Test public void uint64Mod_unsignedLongs() throws Exception { - assertThat(RuntimeHelpers.uint64Mod(UnsignedLong.ONE, UnsignedLong.valueOf(2))) + assertThat(ProtoMessageRuntimeHelpers.uint64Mod(UnsignedLong.ONE, UnsignedLong.valueOf(2))) .isEqualTo(UnsignedLong.ONE); - assertThat(RuntimeHelpers.uint64Mod(UnsignedLong.ONE, UnsignedLong.ONE)) + assertThat(ProtoMessageRuntimeHelpers.uint64Mod(UnsignedLong.ONE, UnsignedLong.ONE)) .isEqualTo(UnsignedLong.ZERO); assertThrows( CelRuntimeException.class, - () -> RuntimeHelpers.uint64Mod(UnsignedLong.ONE, UnsignedLong.ZERO)); + () -> ProtoMessageRuntimeHelpers.uint64Mod(UnsignedLong.ONE, UnsignedLong.ZERO)); } @Test public void uint64Mod_throwsWhenNegativeOrGreaterThanSignedLongMax() throws Exception { - assertThrows(IllegalArgumentException.class, () -> RuntimeHelpers.uint64Mod(0, -1)); - assertThrows(IllegalArgumentException.class, () -> RuntimeHelpers.uint64Mod(-1, -1)); + assertThrows(IllegalArgumentException.class, () -> ProtoMessageRuntimeHelpers.uint64Mod(0, -1)); + assertThrows( + IllegalArgumentException.class, () -> ProtoMessageRuntimeHelpers.uint64Mod(-1, -1)); } @Test public void uint64Mod_unsignedComparisonAndArithmeticIsUnsigned() throws Exception { // In twos complement, -1 is represented by all bits being set. This is equivalent to the // maximum value when unsigned. - assertThat(RuntimeHelpers.uint64Mod(0, -1, CelOptions.DEFAULT)).isEqualTo(0); - assertThat(RuntimeHelpers.uint64Mod(-1, -1, CelOptions.DEFAULT)).isEqualTo(0); + assertThat(ProtoMessageRuntimeHelpers.uint64Mod(0, -1, CelOptions.DEFAULT)).isEqualTo(0); + assertThat(ProtoMessageRuntimeHelpers.uint64Mod(-1, -1, CelOptions.DEFAULT)).isEqualTo(0); } @Test public void uint64Subtract_signedLongs() throws Exception { - assertThat(RuntimeHelpers.uint64Subtract(-1, 2, CelOptions.LEGACY)).isEqualTo(-3); - assertThat(RuntimeHelpers.uint64Subtract(-1, 2, CelOptions.DEFAULT)).isEqualTo(-3); - assertThat(RuntimeHelpers.uint64Subtract(0, 1, CelOptions.LEGACY)).isEqualTo(-1); + assertThat(ProtoMessageRuntimeHelpers.uint64Subtract(-1, 2, CelOptions.LEGACY)).isEqualTo(-3); + assertThat(ProtoMessageRuntimeHelpers.uint64Subtract(-1, 2, CelOptions.DEFAULT)).isEqualTo(-3); + assertThat(ProtoMessageRuntimeHelpers.uint64Subtract(0, 1, CelOptions.LEGACY)).isEqualTo(-1); assertThrows( - ArithmeticException.class, () -> RuntimeHelpers.uint64Subtract(0, 1, CelOptions.DEFAULT)); + ArithmeticException.class, + () -> ProtoMessageRuntimeHelpers.uint64Subtract(0, 1, CelOptions.DEFAULT)); assertThrows( - ArithmeticException.class, () -> RuntimeHelpers.uint64Subtract(-3, -1, CelOptions.DEFAULT)); + ArithmeticException.class, + () -> ProtoMessageRuntimeHelpers.uint64Subtract(-3, -1, CelOptions.DEFAULT)); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.uint64Subtract(55, -40, CelOptions.DEFAULT)); + () -> ProtoMessageRuntimeHelpers.uint64Subtract(55, -40, CelOptions.DEFAULT)); } @Test public void uint64Subtract_unsignedLongs() throws Exception { - assertThat(RuntimeHelpers.uint64Subtract(UnsignedLong.ONE, UnsignedLong.ONE)) + assertThat(ProtoMessageRuntimeHelpers.uint64Subtract(UnsignedLong.ONE, UnsignedLong.ONE)) .isEqualTo(UnsignedLong.ZERO); - assertThat(RuntimeHelpers.uint64Subtract(UnsignedLong.valueOf(3), UnsignedLong.valueOf(2))) + assertThat( + ProtoMessageRuntimeHelpers.uint64Subtract( + UnsignedLong.valueOf(3), UnsignedLong.valueOf(2))) .isEqualTo(UnsignedLong.ONE); assertThrows( ArithmeticException.class, - () -> RuntimeHelpers.uint64Subtract(UnsignedLong.ONE, UnsignedLong.valueOf(2))); + () -> ProtoMessageRuntimeHelpers.uint64Subtract(UnsignedLong.ONE, UnsignedLong.valueOf(2))); } @Test @@ -313,67 +338,43 @@ public void maybeAdaptPrimitive_optionalValues() { @Test public void adaptProtoToValue_wrapperValues() throws Exception { - CelOptions celOptions = CelOptions.LEGACY; - assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, BoolValue.of(true), celOptions)) - .isEqualTo(true); - assertThat( - RuntimeHelpers.adaptProtoToValue( - DYNAMIC_PROTO, BytesValue.of(ByteString.EMPTY), celOptions)) + assertThat(RUNTIME_HELPER.adaptProtoToValue(BoolValue.of(true))).isEqualTo(true); + assertThat(RUNTIME_HELPER.adaptProtoToValue(BytesValue.of(ByteString.EMPTY))) .isEqualTo(ByteString.EMPTY); - assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, DoubleValue.of(1.5d), celOptions)) - .isEqualTo(1.5d); - assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, FloatValue.of(1.5f), celOptions)) - .isEqualTo(1.5d); - assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, Int32Value.of(12), celOptions)) - .isEqualTo(12L); - assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, Int64Value.of(-12L), celOptions)) - .isEqualTo(-12L); - assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, UInt32Value.of(123), celOptions)) - .isEqualTo(123L); - assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, UInt64Value.of(1234L), celOptions)) - .isEqualTo(1234L); - assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, StringValue.of("hello"), celOptions)) - .isEqualTo("hello"); + assertThat(RUNTIME_HELPER.adaptProtoToValue(DoubleValue.of(1.5d))).isEqualTo(1.5d); + assertThat(RUNTIME_HELPER.adaptProtoToValue(FloatValue.of(1.5f))).isEqualTo(1.5d); + assertThat(RUNTIME_HELPER.adaptProtoToValue(Int32Value.of(12))).isEqualTo(12L); + assertThat(RUNTIME_HELPER.adaptProtoToValue(Int64Value.of(-12L))).isEqualTo(-12L); + assertThat(RUNTIME_HELPER.adaptProtoToValue(UInt32Value.of(123))) + .isEqualTo(UnsignedLong.valueOf(123L)); + assertThat(RUNTIME_HELPER.adaptProtoToValue(UInt64Value.of(1234L))) + .isEqualTo(UnsignedLong.valueOf(1234L)); + assertThat(RUNTIME_HELPER.adaptProtoToValue(StringValue.of("hello"))).isEqualTo("hello"); - assertThat( - RuntimeHelpers.adaptProtoToValue( - DYNAMIC_PROTO, - UInt32Value.of(123), - CelOptions.newBuilder().enableUnsignedLongs(true).build())) + assertThat(RUNTIME_HELPER.adaptProtoToValue(UInt32Value.of(123))) .isEqualTo(UnsignedLong.valueOf(123L)); - assertThat( - RuntimeHelpers.adaptProtoToValue( - DYNAMIC_PROTO, - UInt64Value.of(1234L), - CelOptions.newBuilder().enableUnsignedLongs(true).build())) + assertThat(RUNTIME_HELPER.adaptProtoToValue(UInt64Value.of(1234L))) .isEqualTo(UnsignedLong.valueOf(1234L)); } @Test public void adaptProtoToValue_jsonValues() throws Exception { - assertThat( - RuntimeHelpers.adaptProtoToValue( - DYNAMIC_PROTO, - Value.newBuilder().setStringValue("json").build(), - CelOptions.LEGACY)) + assertThat(RUNTIME_HELPER.adaptProtoToValue(Value.newBuilder().setStringValue("json").build())) .isEqualTo("json"); assertThat( - RuntimeHelpers.adaptProtoToValue( - DYNAMIC_PROTO, + RUNTIME_HELPER.adaptProtoToValue( Value.newBuilder() .setListValue( ListValue.newBuilder() .addValues(Value.newBuilder().setNumberValue(1.2d).build())) - .build(), - CelOptions.LEGACY)) + .build())) .isEqualTo(ImmutableList.of(1.2d)); Map mp = new HashMap<>(); mp.put("list_value", ImmutableList.of(false, NullValue.NULL_VALUE)); assertThat( - RuntimeHelpers.adaptProtoToValue( - DYNAMIC_PROTO, + RUNTIME_HELPER.adaptProtoToValue( Struct.newBuilder() .putFields( "list_value", @@ -384,8 +385,7 @@ public void adaptProtoToValue_jsonValues() throws Exception { .addValues( Value.newBuilder().setNullValue(NullValue.NULL_VALUE))) .build()) - .build(), - CelOptions.LEGACY)) + .build())) .isEqualTo(mp); } @@ -405,16 +405,12 @@ public void adaptProtoToValue_anyValues() throws Exception { .build(); Any anyJsonValue = Any.pack(jsonValue); mp.put("list_value", ImmutableList.of(false, NullValue.NULL_VALUE)); - assertThat(RuntimeHelpers.adaptProtoToValue(DYNAMIC_PROTO, anyJsonValue, CelOptions.LEGACY)) - .isEqualTo(mp); + assertThat(RUNTIME_HELPER.adaptProtoToValue(anyJsonValue)).isEqualTo(mp); } @Test public void adaptProtoToValue_builderValue() throws Exception { - CelOptions celOptions = CelOptions.LEGACY; - assertThat( - RuntimeHelpers.adaptProtoToValue( - DYNAMIC_PROTO, BoolValue.newBuilder().setValue(true), celOptions)) + assertThat(RUNTIME_HELPER.adaptProtoToValue(BoolValue.newBuilder().setValue(true))) .isEqualTo(true); } diff --git a/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java b/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java index 0418c0c49..d942295ba 100644 --- a/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java +++ b/runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java @@ -1,4 +1,4 @@ -// Copyright 2022 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,672 +14,27 @@ package dev.cel.runtime; -import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; -import com.google.auto.value.AutoValue; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.primitives.UnsignedLong; -import com.google.protobuf.Any; -import com.google.protobuf.BoolValue; -import com.google.protobuf.ByteString; -import com.google.protobuf.BytesValue; -import com.google.protobuf.DoubleValue; -import com.google.protobuf.FloatValue; -import com.google.protobuf.Int32Value; -import com.google.protobuf.Int64Value; -import com.google.protobuf.ListValue; -import com.google.protobuf.NullValue; -import com.google.protobuf.StringValue; -import com.google.protobuf.Struct; -import com.google.protobuf.UInt32Value; -import com.google.protobuf.UInt64Value; -import com.google.protobuf.Value; -import com.google.protobuf.util.Durations; -import com.google.protobuf.util.Timestamps; -import com.google.rpc.context.AttributeContext; -import com.google.rpc.context.AttributeContext.Auth; -import com.google.rpc.context.AttributeContext.Peer; -import com.google.rpc.context.AttributeContext.Request; -import dev.cel.common.CelDescriptorUtil; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; import dev.cel.common.CelOptions; -import dev.cel.common.CelRuntimeException; -import dev.cel.common.internal.AdaptingTypes; -import dev.cel.common.internal.BidiConverter; -import dev.cel.common.internal.DefaultDescriptorPool; -import dev.cel.common.internal.DefaultMessageFactory; -import dev.cel.common.internal.DynamicProto; -import java.util.Arrays; -import java.util.List; -import org.jspecify.annotations.Nullable; -import org.junit.Assert; +import dev.cel.expr.conformance.proto2.TestAllTypes; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.junit.runners.Parameterized.Parameter; -import org.junit.runners.Parameterized.Parameters; -@RunWith(Parameterized.class) +@RunWith(TestParameterInjector.class) public final class RuntimeEqualityTest { - private static final CelOptions EMPTY_OPTIONS = - CelOptions.newBuilder().disableCelStandardEquality(false).build(); - private static final CelOptions PROTO_EQUALITY = - CelOptions.newBuilder() - .disableCelStandardEquality(false) - .enableProtoDifferencerEquality(true) - .build(); - private static final CelOptions UNSIGNED_LONGS = - CelOptions.newBuilder().disableCelStandardEquality(false).build(); - private static final CelOptions PROTO_EQUALITY_UNSIGNED_LONGS = - CelOptions.newBuilder() - .disableCelStandardEquality(false) - .enableProtoDifferencerEquality(true) - .build(); - - private static final RuntimeEquality RUNTIME_EQUALITY = - new RuntimeEquality( - DynamicProto.create( - DefaultMessageFactory.create( - DefaultDescriptorPool.create( - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - AttributeContext.getDescriptor().getFile()))))); - - @Test - public void inMap() throws Exception { - CelOptions celOptions = CelOptions.newBuilder().disableCelStandardEquality(false).build(); - ImmutableMap map = ImmutableMap.of("key", "value", "key2", "value2"); - assertThat(RUNTIME_EQUALITY.inMap(map, "key2", celOptions)).isTrue(); - assertThat(RUNTIME_EQUALITY.inMap(map, "key3", celOptions)).isFalse(); - - ImmutableMap mixedKeyMap = - ImmutableMap.of( - "key", "value", 2L, "value2", UnsignedLong.valueOf(42), "answer to everything"); - // Integer tests. - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, 2, celOptions)).isTrue(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, 3, celOptions)).isFalse(); - - // Long tests. - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, -1L, celOptions)).isFalse(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, 3L, celOptions)).isFalse(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, 2L, celOptions)).isTrue(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, 42L, celOptions)).isTrue(); - - // Floating point tests - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, -1.0d, celOptions)).isFalse(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, 2.1d, celOptions)).isFalse(); - assertThat( - RUNTIME_EQUALITY.inMap(mixedKeyMap, UnsignedLong.MAX_VALUE.doubleValue(), celOptions)) - .isFalse(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, 2.0d, celOptions)).isTrue(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, Double.NaN, celOptions)).isFalse(); - - // Unsigned long tests. - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, UnsignedLong.valueOf(1L), celOptions)).isFalse(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, UnsignedLong.valueOf(2L), celOptions)).isTrue(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, UnsignedLong.MAX_VALUE, celOptions)).isFalse(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, UInt64Value.of(2L), celOptions)).isTrue(); - - // Validate the legacy behavior as well. - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, 2, CelOptions.LEGACY)).isFalse(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, 2L, CelOptions.LEGACY)).isTrue(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, Int64Value.of(2L), CelOptions.LEGACY)).isFalse(); - assertThat(RUNTIME_EQUALITY.inMap(mixedKeyMap, UInt64Value.of(2L), CelOptions.LEGACY)) - .isFalse(); - } - - @Test - public void inList() throws Exception { - CelOptions celOptions = CelOptions.newBuilder().disableCelStandardEquality(false).build(); - ImmutableList list = ImmutableList.of("value", "value2"); - assertThat(RUNTIME_EQUALITY.inList(list, "value", celOptions)).isTrue(); - assertThat(RUNTIME_EQUALITY.inList(list, "value3", celOptions)).isFalse(); - - ImmutableList mixedValueList = ImmutableList.of(1, "value", 2, "value2"); - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, 2, celOptions)).isTrue(); - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, 3, celOptions)).isFalse(); - - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, 2L, celOptions)).isTrue(); - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, 3L, celOptions)).isFalse(); - - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, 2.0, celOptions)).isTrue(); - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, Double.NaN, celOptions)).isFalse(); - - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, UnsignedLong.valueOf(2L), celOptions)) - .isTrue(); - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, UnsignedLong.valueOf(3L), celOptions)) - .isFalse(); - - // Validate the legacy behavior as well. - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, 2, CelOptions.LEGACY)).isTrue(); - assertThat(RUNTIME_EQUALITY.inList(mixedValueList, 2L, CelOptions.LEGACY)).isFalse(); - } - - @Test - public void indexMap() throws Exception { - ImmutableMap mixedKeyMap = - ImmutableMap.of(1L, "value", UnsignedLong.valueOf(2L), "value2"); - assertThat(RUNTIME_EQUALITY.indexMap(mixedKeyMap, 1.0, CelOptions.DEFAULT)).isEqualTo("value"); - assertThat(RUNTIME_EQUALITY.indexMap(mixedKeyMap, 2.0, CelOptions.DEFAULT)).isEqualTo("value2"); - Assert.assertThrows( - CelRuntimeException.class, - () -> RUNTIME_EQUALITY.indexMap(mixedKeyMap, 1.0, CelOptions.LEGACY)); - Assert.assertThrows( - CelRuntimeException.class, - () -> RUNTIME_EQUALITY.indexMap(mixedKeyMap, 1.1, CelOptions.DEFAULT)); - } - - @AutoValue - abstract static class State { - /** - * Expected comparison outcome when equality is performed with the given options. - * - *

The {@code null} value indicates that the outcome is an error. - */ - public abstract @Nullable Boolean outcome(); - - /** Set of options to use when performing the equality check. */ - public abstract CelOptions celOptions(); - - public static State create(@Nullable Boolean outcome, CelOptions celOptions) { - return new AutoValue_RuntimeEqualityTest_State(outcome, celOptions); - } - } - - /** Represents expected result states for an equality test case. */ - @AutoValue - abstract static class Result { - - /** The result {@code State} value associated with different feature flag combinations. */ - public abstract ImmutableSet states(); - - /** - * Creates a Result for a comparison that is undefined (throws an Exception) under both equality - * modes. - */ - public static Result undefined() { - return always(null); - } - - /** Creates a Result for a comparison that is false under both equality modes. */ - public static Result alwaysFalse() { - return always(false); - } - - /** Creates a Result for a comparison that is true under both equality modes. */ - public static Result alwaysTrue() { - return always(true); - } - - public static Result signed(Boolean outcome) { - return Result.builder() - .states( - ImmutableList.of( - State.create(outcome, EMPTY_OPTIONS), State.create(outcome, PROTO_EQUALITY))) - .build(); - } - - public static Result unsigned(Boolean outcome) { - return Result.builder() - .states( - ImmutableList.of( - State.create(outcome, UNSIGNED_LONGS), - State.create(outcome, PROTO_EQUALITY_UNSIGNED_LONGS))) - .build(); - } - - private static Result always(@Nullable Boolean outcome) { - return Result.builder() - .states( - ImmutableList.of( - State.create(outcome, EMPTY_OPTIONS), - State.create(outcome, PROTO_EQUALITY), - State.create(outcome, PROTO_EQUALITY_UNSIGNED_LONGS))) - .build(); - } - - private static Result proto(Boolean equalsOutcome, Boolean diffOutcome) { - return Result.builder() - .states( - ImmutableList.of( - State.create(equalsOutcome, EMPTY_OPTIONS), - State.create(diffOutcome, PROTO_EQUALITY), - State.create(diffOutcome, PROTO_EQUALITY_UNSIGNED_LONGS))) - .build(); - } - - public static Builder builder() { - return new AutoValue_RuntimeEqualityTest_Result.Builder(); - } - - @AutoValue.Builder - public abstract static class Builder { - abstract Builder states(ImmutableList states); - - abstract Result build(); - } - } - - @Parameter(0) - public Object lhs; - - @Parameter(1) - public Object rhs; - - @Parameter(2) - public Result result; - - @Parameters - public static List data() { - return Arrays.asList( - new Object[][] { - // Boolean tests. - {true, true, Result.alwaysTrue()}, - {BoolValue.of(true), true, Result.alwaysTrue()}, - {Any.pack(BoolValue.of(true)), true, Result.alwaysTrue()}, - {Value.newBuilder().setBoolValue(true).build(), true, Result.alwaysTrue()}, - {true, false, Result.alwaysFalse()}, - {0, false, Result.alwaysFalse()}, - - // Bytes tests. - {ByteString.copyFromUtf8("h¢"), ByteString.copyFromUtf8("h¢"), Result.alwaysTrue()}, - {ByteString.copyFromUtf8("hello"), ByteString.EMPTY, Result.alwaysFalse()}, - {BytesValue.of(ByteString.EMPTY), ByteString.EMPTY, Result.alwaysTrue()}, - { - BytesValue.of(ByteString.copyFromUtf8("h¢")), - ByteString.copyFromUtf8("h¢"), - Result.alwaysTrue() - }, - {Any.pack(BytesValue.of(ByteString.EMPTY)), ByteString.EMPTY, Result.alwaysTrue()}, - {"h¢", ByteString.copyFromUtf8("h¢"), Result.alwaysFalse()}, - - // Double tests. - {1.0, 1.0, Result.alwaysTrue()}, - {Double.valueOf(1.0), 1.0, Result.alwaysTrue()}, - {DoubleValue.of(42.5), 42.5, Result.alwaysTrue()}, - // Floats are unwrapped to double types. - {FloatValue.of(1.0f), 1.0, Result.alwaysTrue()}, - {Value.newBuilder().setNumberValue(-1.5D).build(), -1.5, Result.alwaysTrue()}, - {1.0, -1.0, Result.alwaysFalse()}, - {1.0, 1.0D, Result.alwaysTrue()}, - {1.0, 1.1D, Result.alwaysFalse()}, - {1.0D, 1.1f, Result.alwaysFalse()}, - {1.0, 1, Result.alwaysTrue()}, - - // Float tests. - {1.0f, 1.0f, Result.alwaysTrue()}, - {Float.valueOf(1.0f), 1.0f, Result.alwaysTrue()}, - {1.0f, -1.0f, Result.alwaysFalse()}, - {1.0f, 1.0, Result.alwaysTrue()}, - - // Integer tests. - {16, 16, Result.alwaysTrue()}, - {17, 16, Result.alwaysFalse()}, - {17, 16.0, Result.alwaysFalse()}, - - // Long tests. - {-15L, -15L, Result.alwaysTrue()}, - // Int32 values are unwrapped to int types. - {Int32Value.of(-15), -15L, Result.alwaysTrue()}, - {Int64Value.of(-15L), -15L, Result.alwaysTrue()}, - {Any.pack(Int32Value.of(-15)), -15L, Result.alwaysTrue()}, - {Any.pack(Int64Value.of(-15L)), -15L, Result.alwaysTrue()}, - {-15L, -16L, Result.alwaysFalse()}, - {-15L, -15, Result.alwaysTrue()}, - {-15L, 15.0, Result.alwaysFalse()}, - - // Null tests. - {null, null, Result.alwaysTrue()}, - {false, null, Result.alwaysFalse()}, - {0.0, null, Result.alwaysFalse()}, - {0, null, Result.alwaysFalse()}, - {null, "null", Result.alwaysFalse()}, - {"null", null, Result.alwaysFalse()}, - {null, NullValue.NULL_VALUE, Result.alwaysTrue()}, - {null, ImmutableList.of(), Result.alwaysFalse()}, - {ImmutableMap.of(), null, Result.alwaysFalse()}, - {ByteString.copyFromUtf8(""), null, Result.alwaysFalse()}, - {null, Timestamps.EPOCH, Result.alwaysFalse()}, - {Durations.ZERO, null, Result.alwaysFalse()}, - {NullValue.NULL_VALUE, NullValue.NULL_VALUE, Result.alwaysTrue()}, - { - Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(), - NullValue.NULL_VALUE, - Result.alwaysTrue() - }, - { - Any.pack(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()), - NullValue.NULL_VALUE, - Result.alwaysTrue() - }, - - // String tests. - {"", "", Result.alwaysTrue()}, - {"str", "str", Result.alwaysTrue()}, - {StringValue.of("str"), "str", Result.alwaysTrue()}, - {Value.newBuilder().setStringValue("str").build(), "str", Result.alwaysTrue()}, - {Any.pack(StringValue.of("str")), "str", Result.alwaysTrue()}, - {Any.pack(Value.newBuilder().setStringValue("str").build()), "str", Result.alwaysTrue()}, - {"", "non-empty", Result.alwaysFalse()}, - - // Uint tests. - {UInt32Value.of(1234), 1234L, Result.alwaysTrue()}, - {UInt64Value.of(1234L), 1234L, Result.alwaysTrue()}, - {UInt64Value.of(1234L), Int64Value.of(1234L), Result.alwaysTrue()}, - {UInt32Value.of(1234), UnsignedLong.valueOf(1234L), Result.alwaysTrue()}, - {UInt64Value.of(1234L), UnsignedLong.valueOf(1234L), Result.alwaysTrue()}, - {Any.pack(UInt64Value.of(1234L)), UnsignedLong.valueOf(1234L), Result.alwaysTrue()}, - {UInt32Value.of(123), UnsignedLong.valueOf(1234L), Result.alwaysFalse()}, - {UInt64Value.of(123L), UnsignedLong.valueOf(1234L), Result.alwaysFalse()}, - {Any.pack(UInt64Value.of(123L)), UnsignedLong.valueOf(1234L), Result.alwaysFalse()}, - - // Cross-type equality tests. - {UInt32Value.of(1234), 1234.0, Result.alwaysTrue()}, - {UInt32Value.of(1234), 1234.0, Result.alwaysTrue()}, - {UInt64Value.of(1234L), 1234L, Result.alwaysTrue()}, - {UInt32Value.of(1234), 1234.1, Result.alwaysFalse()}, - {UInt64Value.of(1234L), 1233L, Result.alwaysFalse()}, - {UnsignedLong.valueOf(1234L), 1234L, Result.alwaysTrue()}, - {UnsignedLong.valueOf(1234L), 1234.1, Result.alwaysFalse()}, - {1234L, 1233.2, Result.alwaysFalse()}, - {-1234L, UnsignedLong.valueOf(1233L), Result.alwaysFalse()}, - - // List tests. - // Note, this list equality behaves equivalently to the following expression: - // 1.0 == 1.0 && "dos" == 2.0 && 3.0 == 4.0 - // The middle predicate is an error; however, the last comparison yields false and so - - // the error is short-circuited away. - {Arrays.asList(1.0, "dos", 3.0), Arrays.asList(1.0, 2.0, 4.0), Result.alwaysFalse()}, - {Arrays.asList("1", 2), ImmutableList.of("1", 2), Result.alwaysTrue()}, - {Arrays.asList("1", 2), ImmutableSet.of("1", 2), Result.alwaysTrue()}, - {Arrays.asList(1.0, 2.0, 3.0), Arrays.asList(1.0, 2.0), Result.alwaysFalse()}, - {Arrays.asList(1.0, 3.0), Arrays.asList(1.0, 2.0), Result.alwaysFalse()}, - { - AdaptingTypes.adaptingList( - ImmutableList.of(1, 2, 3), - BidiConverter.of(RuntimeHelpers.INT32_TO_INT64, RuntimeHelpers.INT64_TO_INT32)), - Arrays.asList(1L, 2L, 3L), - Result.alwaysTrue() - }, - { - ListValue.newBuilder() - .addValues(Value.newBuilder().setStringValue("hello")) - .addValues(Value.newBuilder().setStringValue("world")) - .build(), - ImmutableList.of("hello", "world"), - Result.alwaysTrue() - }, - { - ListValue.newBuilder() - .addValues(Value.newBuilder().setStringValue("hello")) - .addValues(Value.newBuilder().setListValue(ListValue.getDefaultInstance())) - .build(), - ImmutableList.of("hello", "world"), - Result.alwaysFalse() - }, - { - ListValue.newBuilder() - .addValues(Value.newBuilder().setListValue(ListValue.getDefaultInstance())) - .addValues( - Value.newBuilder() - .setListValue( - ListValue.newBuilder() - .addValues(Value.newBuilder().setBoolValue(true)))) - .build(), - ImmutableList.of(ImmutableList.of(), ImmutableList.of(true)), - Result.alwaysTrue() - }, - { - Value.newBuilder() - .setListValue( - ListValue.newBuilder() - .addValues(Value.newBuilder().setNumberValue(-1.5)) - .addValues(Value.newBuilder().setNumberValue(42.25))) - .build(), - AdaptingTypes.adaptingList( - ImmutableList.of(-1.5f, 42.25f), - BidiConverter.of(RuntimeHelpers.FLOAT_TO_DOUBLE, RuntimeHelpers.DOUBLE_TO_FLOAT)), - Result.alwaysTrue() - }, - - // Map tests. - {ImmutableMap.of("one", 1), ImmutableMap.of("one", "uno"), Result.alwaysFalse()}, - {ImmutableMap.of("two", 2), ImmutableMap.of("two", 3), Result.alwaysFalse()}, - {ImmutableMap.of("one", 2), ImmutableMap.of("two", 3), Result.alwaysFalse()}, - // Note, this map is the composition of the following two tests above where: - // ("one", 1) == ("one", "uno") -> error - // ("two", 2) == ("two", 3) -> false - // Within CEL error && false -> false, and the key order in the test has specifically - // been chosen to exercise this behavior. - { - ImmutableMap.of("one", 1, "two", 2), - ImmutableMap.of("one", "uno", "two", 3), - Result.alwaysFalse() - }, - {ImmutableMap.of("key", "value"), ImmutableMap.of("key", "value"), Result.alwaysTrue()}, - {ImmutableMap.of(), ImmutableMap.of("key", "value"), Result.alwaysFalse()}, - {ImmutableMap.of("key", "value"), ImmutableMap.of("key", "diff"), Result.alwaysFalse()}, - {ImmutableMap.of("key", 42), ImmutableMap.of("key", 42L), Result.alwaysTrue()}, - {ImmutableMap.of("key", 42.0), ImmutableMap.of("key", 42L), Result.alwaysTrue()}, - { - AdaptingTypes.adaptingMap( - ImmutableMap.of("key1", 42, "key2", 31, "key3", 20), - BidiConverter.identity(), - BidiConverter.of(RuntimeHelpers.INT32_TO_INT64, RuntimeHelpers.INT64_TO_INT32)), - ImmutableMap.of("key1", 42L, "key2", 31L, "key3", 20L), - Result.alwaysTrue() - }, - { - AdaptingTypes.adaptingMap( - ImmutableMap.of(1, 42.5f, 2, 31f, 3, 20.25f), - BidiConverter.of(RuntimeHelpers.INT32_TO_INT64, RuntimeHelpers.INT64_TO_INT32), - BidiConverter.of(RuntimeHelpers.FLOAT_TO_DOUBLE, RuntimeHelpers.DOUBLE_TO_FLOAT)), - ImmutableMap.of(1L, 42.5D, 2L, 31D, 3L, 20.25D), - Result.alwaysTrue() - }, - { - AdaptingTypes.adaptingMap( - ImmutableMap.of("1", 42.5f, "2", 31f, "3", 20.25f), - BidiConverter.identity(), - BidiConverter.of(RuntimeHelpers.FLOAT_TO_DOUBLE, RuntimeHelpers.DOUBLE_TO_FLOAT)), - Struct.getDefaultInstance(), - Result.alwaysFalse() - }, - { - AdaptingTypes.adaptingMap( - ImmutableMap.of("1", 42.5f, "2", 31f, "3", 20.25f), - BidiConverter.identity(), - BidiConverter.of(RuntimeHelpers.FLOAT_TO_DOUBLE, RuntimeHelpers.DOUBLE_TO_FLOAT)), - Struct.newBuilder() - .putFields("1", Value.newBuilder().setNumberValue(42.5D).build()) - .putFields("2", Value.newBuilder().setNumberValue(31D).build()) - .putFields("3", Value.newBuilder().setNumberValue(20.25D).build()) - .build(), - Result.alwaysTrue() - }, - { - AdaptingTypes.adaptingMap( - ImmutableMap.of("1", 42.5f, "2", 31f, "3", 20.25f), - BidiConverter.identity(), - BidiConverter.of(RuntimeHelpers.FLOAT_TO_DOUBLE, RuntimeHelpers.DOUBLE_TO_FLOAT)), - Struct.newBuilder() - .putFields("1", Value.newBuilder().setNumberValue(42.5D).build()) - .putFields("2", Value.newBuilder().setNumberValue(31D).build()) - .putFields("3", Value.newBuilder().setStringValue("oops").build()) - .build(), - Result.alwaysFalse() - }, - - // Protobuf tests. - { - AttributeContext.newBuilder().setRequest(Request.getDefaultInstance()).build(), - AttributeContext.newBuilder().setRequest(Request.newBuilder().setHost("")).build(), - Result.alwaysTrue() - }, - { - AttributeContext.newBuilder() - .setRequest(Request.getDefaultInstance()) - .setOrigin(Peer.getDefaultInstance()) - .build(), - AttributeContext.newBuilder().setRequest(Request.getDefaultInstance()).build(), - Result.alwaysFalse() - }, - // Proto differencer unpacks any values. - { - AttributeContext.newBuilder() - .addExtensions( - Any.newBuilder() - .setTypeUrl("type.googleapis.com/google.rpc.context.AttributeContext") - .setValue(ByteString.copyFromUtf8("\032\000:\000")) - .build()) - .build(), - AttributeContext.newBuilder() - .addExtensions( - Any.newBuilder() - .setTypeUrl("type.googleapis.com/google.rpc.context.AttributeContext") - .setValue(ByteString.copyFromUtf8(":\000\032\000")) - .build()) - .build(), - Result.builder() - .states( - ImmutableList.of( - State.create(false, EMPTY_OPTIONS), State.create(true, PROTO_EQUALITY))) - .build() - }, - // If type url is missing, fallback to bytes comparison for payload. - { - AttributeContext.newBuilder() - .addExtensions( - Any.newBuilder().setValue(ByteString.copyFromUtf8("\032\000:\000")).build()) - .build(), - AttributeContext.newBuilder() - .addExtensions( - Any.newBuilder().setValue(ByteString.copyFromUtf8(":\000\032\000")).build()) - .build(), - Result.alwaysFalse() - }, - { - AttributeContext.newBuilder() - .setRequest(Request.getDefaultInstance()) - .setOrigin(Peer.getDefaultInstance()) - .build(), - "test string", - Result.alwaysFalse() - }, - { - AttributeContext.newBuilder() - .setRequest(Request.getDefaultInstance()) - .setOrigin(Peer.getDefaultInstance()) - .build(), - null, - Result.alwaysFalse() - }, - { - AttributeContext.newBuilder() - .addExtensions( - Any.pack( - AttributeContext.newBuilder() - .setRequest(Request.getDefaultInstance()) - .setOrigin(Peer.getDefaultInstance()) - .build())) - .build(), - AttributeContext.newBuilder() - .addExtensions( - Any.pack( - AttributeContext.newBuilder() - .setRequest(Request.getDefaultInstance()) - .build())) - .build(), - Result.alwaysFalse() - }, - { - AttributeContext.getDefaultInstance(), - AttributeContext.newBuilder() - .setRequest(Request.newBuilder().setHost("localhost")) - .build(), - Result.alwaysFalse() - }, - // Differently typed messages aren't comparable. - {AttributeContext.getDefaultInstance(), Auth.getDefaultInstance(), Result.alwaysFalse()}, - // Message.equals() treats NaN values as equal. Message differencer treats NaN values - // as inequal (the same behavior as the C++ implementation). - { - AttributeContext.newBuilder() - .setRequest( - Request.newBuilder() - .setAuth( - Auth.newBuilder() - .setClaims( - Struct.newBuilder() - .putFields( - "custom", - Value.newBuilder() - .setNumberValue(Double.NaN) - .build())))) - .build(), - AttributeContext.newBuilder() - .setRequest( - Request.newBuilder() - .setAuth( - Auth.newBuilder() - .setClaims( - Struct.newBuilder() - .putFields( - "custom", - Value.newBuilder() - .setNumberValue(Double.NaN) - .build())))) - .build(), - Result.proto(/* equalsOutcome= */ true, /* diffOutcome= */ false), - }, - - // Note: this is the motivating use case for converting to heterogeneous equality in - // the future. - { - AttributeContext.newBuilder() - .setRequest( - Request.newBuilder() - .setAuth( - Auth.newBuilder() - .setClaims( - Struct.newBuilder() - .putFields( - "custom", - Value.newBuilder().setNumberValue(123.0).build())))) - .build(), - AttributeContext.newBuilder() - .setRequest( - Request.newBuilder() - .setAuth( - Auth.newBuilder() - .setClaims( - Struct.newBuilder() - .putFields( - "custom", - Value.newBuilder().setBoolValue(true).build())))) - .build(), - Result.alwaysFalse(), - }, - }); - } @Test - public void objectEquals() throws Exception { - for (State state : result.states()) { - if (state.outcome() == null) { - Assert.assertThrows( - CelRuntimeException.class, - () -> RUNTIME_EQUALITY.objectEquals(lhs, rhs, state.celOptions())); - Assert.assertThrows( - CelRuntimeException.class, - () -> RUNTIME_EQUALITY.objectEquals(rhs, lhs, state.celOptions())); - return; - } - assertThat(RUNTIME_EQUALITY.objectEquals(lhs, rhs, state.celOptions())) - .isEqualTo(state.outcome()); - assertThat(RUNTIME_EQUALITY.objectEquals(rhs, lhs, state.celOptions())) - .isEqualTo(state.outcome()); - } + public void objectEquals_messageLite_throws() { + RuntimeEquality runtimeEquality = + RuntimeEquality.create(RuntimeHelpers.create(), CelOptions.DEFAULT); + + // Unimplemented until CelLiteDescriptor is available. + assertThrows( + UnsupportedOperationException.class, + () -> + runtimeEquality.objectEquals( + TestAllTypes.newBuilder(), TestAllTypes.getDefaultInstance())); } } diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index 5c8e06a98..6fa78cf9f 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -107,7 +107,6 @@ java_library( "//common/types:type_providers", "//extensions:optional_library", "//runtime", - "//runtime:runtime_helper", "//runtime:unknown_attributes", "@cel_spec//proto/cel/expr:checked_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index b7f524761..be57a0de4 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -75,7 +75,6 @@ import dev.cel.runtime.CelRuntimeFactory; import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.CelVariableResolver; -import dev.cel.runtime.RuntimeHelpers; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.io.IOException; import java.util.Arrays; @@ -227,7 +226,7 @@ public void arithmInt64() { declareVariable("y", CelProtoTypes.DYN); source = "x + y == 1"; - runTest(extend(ImmutableMap.of("x", -5L), ImmutableMap.of("y", 6))); + runTest(extend(ImmutableMap.of("x", -5L), ImmutableMap.of("y", 6L))); } @Test @@ -2333,7 +2332,7 @@ private static TestOnlyVariableResolver newInstance(Map map) { @Override public Optional find(String name) { - return Optional.ofNullable(RuntimeHelpers.maybeAdaptPrimitive(map.get(name))); + return Optional.ofNullable(map.get(name)); } @Override