Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bundle/src/test/java/dev/cel/bundle/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ java_library(
"@maven//:com_google_truth_extensions_truth_proto_extension",
"@maven//:junit_junit",
"@maven//:org_jspecify_jspecify",
"//testing/protos:single_file_java_proto",
],
)

Expand Down
19 changes: 19 additions & 0 deletions bundle/src/test/java/dev/cel/bundle/CelImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
import dev.cel.runtime.CelUnknownSet;
import dev.cel.runtime.CelVariableResolver;
import dev.cel.runtime.UnknownContext;
import dev.cel.testing.testdata.SingleFileProto.SingleFile;
import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum;
import java.time.Instant;
import java.util.ArrayList;
Expand Down Expand Up @@ -2193,6 +2194,24 @@ public void toBuilder_isImmutable() {
assertThat(newRuntimeBuilder).isNotEqualTo(celImpl.toRuntimeBuilder());
}

@Test
public void eval_withJsonFieldName()
throws Exception {
Cel cel =
standardCelBuilderWithMacros()
.addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName()))
.addMessageTypes(SingleFile.getDescriptor())
.setOptions(CelOptions.current().enableJsonFieldNames(true).build())
.build();
CelAbstractSyntaxTree ast =
cel.compile("file.camelCased").getAst();

Object result = cel.createProgram(ast).eval(ImmutableMap.of("file", SingleFile.newBuilder().setSnakeCased("foo").build()));

assertThat(result).isEqualTo("foo");
}


private static TypeProvider aliasingProvider(ImmutableMap<String, Type> typeAliases) {
return new TypeProvider() {
@Override
Expand Down
10 changes: 7 additions & 3 deletions checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,13 @@ public CelCheckerLegacyImpl build() {
}

CelTypeProvider messageTypeProvider =
new ProtoMessageTypeProvider(
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(
fileTypeSet, celOptions.resolveTypeDependencies()));
ProtoMessageTypeProvider.newBuilder()
.setAllowJsonFieldNames(celOptions.enableJsonFieldNames())
.setCelDescriptors(
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(
fileTypeSet, celOptions.resolveTypeDependencies()))
.build();

if (celTypeProvider != null && fileTypeSet.isEmpty()) {
messageTypeProvider = celTypeProvider;
} else if (celTypeProvider != null) {
Expand Down
12 changes: 11 additions & 1 deletion common/src/main/java/dev/cel/common/CelOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ public enum ProtoUnsetFieldOptions {

public abstract boolean enableNamespacedDeclarations();

public abstract boolean enableJsonFieldNames();

// Evaluation related options

public abstract boolean disableCelStandardEquality();
Expand Down Expand Up @@ -150,6 +152,7 @@ public static Builder newBuilder() {
.enableTimestampEpoch(false)
.enableHeterogeneousNumericComparisons(false)
.enableNamespacedDeclarations(true)
.enableJsonFieldNames(false)
// Evaluation options
.disableCelStandardEquality(true)
.evaluateCanonicalTypesToNativeValues(false)
Expand All @@ -170,7 +173,8 @@ public static Builder newBuilder() {
.enableStringConcatenation(true)
.enableListConcatenation(true)
.enableComprehension(true)
.maxRegexProgramSize(-1);
.maxRegexProgramSize(-1)
;
}

/**
Expand Down Expand Up @@ -529,6 +533,12 @@ public abstract static class Builder {
*/
public abstract Builder maxRegexProgramSize(int value);

/**
* TODO
*/
public abstract Builder enableJsonFieldNames(boolean value);


public abstract CelOptions build();
}
}
140 changes: 111 additions & 29 deletions common/src/main/java/dev/cel/common/types/ProtoMessageTypeProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@

package dev.cel.common.types;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;

import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
Expand All @@ -34,11 +32,11 @@
import dev.cel.common.CelDescriptorUtil;
import dev.cel.common.CelDescriptors;
import dev.cel.common.internal.FileDescriptorSetConverter;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

/**
* The {@code ProtoMessageTypeProvider} implements the {@link CelTypeProvider} interface to provide
Expand Down Expand Up @@ -68,35 +66,35 @@ public final class ProtoMessageTypeProvider implements CelTypeProvider {
.buildOrThrow();

private final ImmutableMap<String, CelType> allTypes;
private final boolean allowJsonFieldNames;

/** Returns a new builder for {@link ProtoMessageTypeProvider}. */
public static Builder newBuilder() {
return new Builder();
}

@Deprecated
public ProtoMessageTypeProvider() {
this(CelDescriptors.builder().build());
this(CelDescriptors.builder().build(), false);
}

@Deprecated
public ProtoMessageTypeProvider(FileDescriptorSet descriptorSet) {
this(
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(
FileDescriptorSetConverter.convert(descriptorSet)));
FileDescriptorSetConverter.convert(descriptorSet)), false);
}

@Deprecated
public ProtoMessageTypeProvider(Iterable<Descriptor> descriptors) {
this(
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(
ImmutableSet.copyOf(Iterables.transform(descriptors, Descriptor::getFile))));
ImmutableSet.copyOf(Iterables.transform(descriptors, Descriptor::getFile))), false);
}

@Deprecated
public ProtoMessageTypeProvider(ImmutableSet<FileDescriptor> fileDescriptors) {
this(CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors));
}

public ProtoMessageTypeProvider(CelDescriptors celDescriptors) {
this.allTypes =
ImmutableMap.<String, CelType>builder()
.putAll(createEnumTypes(celDescriptors.enumDescriptors()))
.putAll(
createProtoMessageTypes(
celDescriptors.messageTypeDescriptors(), celDescriptors.extensionDescriptors()))
.buildOrThrow();
this(CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors), false);
}

@Override
Expand All @@ -120,8 +118,14 @@ private ImmutableMap<String, CelType> createProtoMessageTypes(
if (protoMessageTypes.containsKey(descriptor.getFullName())) {
continue;
}
ImmutableList<String> fieldNames =
descriptor.getFields().stream().map(FieldDescriptor::getName).collect(toImmutableList());

ImmutableSet.Builder<String> fieldNamesBuilder = ImmutableSet.builder() ;
for (FieldDescriptor fd : descriptor.getFields()) {
fieldNamesBuilder.add(fd.getName());
if (allowJsonFieldNames) {
fieldNamesBuilder.add(fd.getJsonName());
}
}

Map<String, FieldDescriptor> extensionFields = new HashMap<>();
for (FieldDescriptor extension : extensionMap.get(descriptor.getFullName())) {
Expand All @@ -133,7 +137,7 @@ private ImmutableMap<String, CelType> createProtoMessageTypes(
descriptor.getFullName(),
ProtoMessageType.create(
descriptor.getFullName(),
ImmutableSet.copyOf(fieldNames),
fieldNamesBuilder.build(),
new FieldResolver(this, descriptor)::findField,
new FieldResolver(this, extensions)::findField));
}
Expand All @@ -158,19 +162,35 @@ private ImmutableMap<String, CelType> createEnumTypes(
}

private static class FieldResolver {
private final CelTypeProvider celTypeProvider;
private final ProtoMessageTypeProvider protoMessageTypeProvider;
private final ImmutableMap<String, FieldDescriptor> fields;

private FieldResolver(CelTypeProvider celTypeProvider, Descriptor descriptor) {
private static ImmutableMap<String, FieldDescriptor> collectFieldDescriptorMap(ProtoMessageTypeProvider protoMessageTypeProvider, Descriptor descriptor) {
ImmutableMap.Builder<String, FieldDescriptor> builder = ImmutableMap.builder();
for (FieldDescriptor fd : descriptor.getFields()) {
builder.put(fd.getName(), fd);
if (protoMessageTypeProvider.allowJsonFieldNames) {
builder.put(fd.getJsonName(), fd);
}
}

return builder.buildKeepingLast();
}

private FieldResolver(
ProtoMessageTypeProvider protoMessageTypeProvider,
Descriptor descriptor
) {
this(
celTypeProvider,
descriptor.getFields().stream()
.collect(toImmutableMap(FieldDescriptor::getName, Function.identity())));
protoMessageTypeProvider,
collectFieldDescriptorMap(protoMessageTypeProvider, descriptor)
);
}

private FieldResolver(
CelTypeProvider celTypeProvider, ImmutableMap<String, FieldDescriptor> fields) {
this.celTypeProvider = celTypeProvider;
ProtoMessageTypeProvider protoMessageTypeProvider,
ImmutableMap<String, FieldDescriptor> fields) {
this.protoMessageTypeProvider = protoMessageTypeProvider;
this.fields = fields;
}

Expand Down Expand Up @@ -203,11 +223,11 @@ private Optional<CelType> findFieldInternal(FieldDescriptor fieldDescriptor) {
String messageName = descriptor.getFullName();
fieldType =
CelTypes.getWellKnownCelType(messageName)
.orElse(celTypeProvider.findType(descriptor.getFullName()).orElse(null));
.orElse(protoMessageTypeProvider.findType(descriptor.getFullName()).orElse(null));
break;
case ENUM:
EnumDescriptor enumDescriptor = fieldDescriptor.getEnumType();
fieldType = celTypeProvider.findType(enumDescriptor.getFullName()).orElse(null);
fieldType = protoMessageTypeProvider.findType(enumDescriptor.getFullName()).orElse(null);
break;
default:
fieldType = PROTO_TYPE_TO_CEL_TYPE.get(fieldDescriptor.getType());
Expand All @@ -222,4 +242,66 @@ private Optional<CelType> findFieldInternal(FieldDescriptor fieldDescriptor) {
return Optional.of(fieldType);
}
}

/** Builder for {@link ProtoMessageTypeProvider}. */
public static final class Builder {
private final ImmutableSet.Builder<FileDescriptor> fileDescriptors = ImmutableSet.builder();
private boolean allowJsonFieldNames;
private CelDescriptors celDescriptors;

/** Adds a {@link FileDescriptor} to the provider. */
public Builder addFileDescriptors(FileDescriptor... fileDescriptors) {
return addFileDescriptors(Arrays.asList(fileDescriptors));
}

/** Adds a collection of {@link FileDescriptor}s to the provider. */
public Builder addFileDescriptors(Iterable<FileDescriptor> fileDescriptors) {
this.fileDescriptors.addAll(fileDescriptors);
return this;
}

/** Adds a collection of {@link Descriptor}s. The parent file of each descriptor is added. */
public Builder addDescriptors(Iterable<Descriptor> descriptors) {
this.fileDescriptors.addAll(Iterables.transform(descriptors, Descriptor::getFile));
return this;
}

/** Adds a {@link FileDescriptorSet} to the provider. */
public Builder addFileDescriptorSet(FileDescriptorSet fileDescriptorSet) {
this.fileDescriptors.addAll(FileDescriptorSetConverter.convert(fileDescriptorSet));
return this;
}

public Builder setAllowJsonFieldNames(boolean allowJsonFieldNames) {
this.allowJsonFieldNames = allowJsonFieldNames;
return this;
}

public Builder setCelDescriptors(CelDescriptors celDescriptors) {
this.celDescriptors = celDescriptors;
return this;
}

/** Builds the {@link ProtoMessageTypeProvider}. */
public ProtoMessageTypeProvider build() {
// CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors.build());
return new ProtoMessageTypeProvider(celDescriptors, allowJsonFieldNames);
}

private Builder() {}
}

private ProtoMessageTypeProvider(
CelDescriptors celDescriptors,
boolean allowJsonFieldNames
) {
this.allowJsonFieldNames = allowJsonFieldNames;
this.allTypes =
ImmutableMap.<String, CelType>builder()
.putAll(createEnumTypes(celDescriptors.enumDescriptors()))
.putAll(
createProtoMessageTypes(
celDescriptors.messageTypeDescriptors(), celDescriptors.extensionDescriptors()))
.buildOrThrow();
}
}
1 change: 1 addition & 0 deletions common/src/test/java/dev/cel/common/types/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ java_library(
"//common/types:cel_types",
"//common/types:message_type_provider",
"//common/types:type_providers",
"//testing/protos:single_file_java_proto",
"@cel_spec//proto/cel/expr:checked_java_proto",
"@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto",
"@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import dev.cel.common.types.CelTypeProvider.CombinedCelTypeProvider;
import dev.cel.common.types.StructType.Field;
import dev.cel.expr.conformance.proto2.TestAllTypes;
import dev.cel.expr.conformance.proto2.TestAllTypesExtensions;
import dev.cel.testing.testdata.SingleFileProto.SingleFile;
import java.util.Optional;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -254,4 +256,22 @@ public void types_combinedDuplicateProviderIsSameAsFirst() {
CombinedCelTypeProvider combined = new CombinedCelTypeProvider(proto3Provider, proto3Provider);
assertThat(combined.types()).hasSize(proto3Provider.types().size());
}

@Test
public void findField_withJsonNameOption() {
ProtoMessageTypeProvider typeProvider =
ProtoMessageTypeProvider.newBuilder()
.addFileDescriptors(SingleFile.getDescriptor().getFile())
.setAllowJsonFieldNames(true)
.build();

ProtoMessageType msgType = (ProtoMessageType) typeProvider.findType(SingleFile.getDescriptor().getFullName()).get();

// Note that these are the same fields, with json_name option set
Optional<Field> snakeCasedField = msgType.findField("snake_cased");
Optional<Field> jsonNameField = msgType.findField("camelCased");

assertThat(snakeCasedField).hasValue(Field.of("snake_cased", SimpleType.STRING));
assertThat(jsonNameField).hasValue(Field.of("camelCased", SimpleType.STRING));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ private FieldDescriptor findField(Descriptor descriptor, String fieldName) {
}
}

if (fieldDescriptor == null && celOptions.enableJsonFieldNames()) {
for (FieldDescriptor fd : descriptor.getFields()) {
if (fd.getJsonName().equals(fieldName)) {
fieldDescriptor = fd;
break;
}
}
}

if (fieldDescriptor == null) {
throw new IllegalArgumentException(
String.format(
Expand Down
1 change: 1 addition & 0 deletions testing/src/test/resources/protos/single_file.proto
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ message SingleFile {

string name = 1;
Path path = 2;
string snake_cased = 3 [json_name = "camelCased"];
}
Loading